Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BED-5060: Prevent user from changing their own role/auth #984

Merged
merged 11 commits into from
Dec 6, 2024
2 changes: 2 additions & 0 deletions cmd/api/src/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const (
ErrorResponsePayloadUnmarshalError = "error unmarshalling JSON payload"
ErrorResponseRequestTimeout = "request timed out"
ErrorResponseUserSelfDisable = "user attempted to disable themselves"
ErrorResponseUserSelfRoleChange = "user attempted to change own role"
ErrorResponseUserSelfSSOChange = "user attempted to change own SSO"
wes-mil marked this conversation as resolved.
Show resolved Hide resolved
ErrorResponseAGTagWhiteSpace = "asset group tags must not contain whitespace"
ErrorResponseAGNameTagEmpty = "asset group name or tag must not be empty"
ErrorResponseAGDuplicateName = "asset group name must be unique"
Expand Down
33 changes: 15 additions & 18 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,16 +364,6 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
}
}

func (s ManagementResource) ensureUserHasNoAuthSecret(ctx context.Context, user model.User) error {
if user.AuthSecret != nil {
if err := s.db.DeleteAuthSecret(ctx, *user.AuthSecret); err != nil {
return api.FormatDatabaseError(err)
}
}

return nil
}

func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *http.Request) {
var (
updateUserRequest v2.UpdateUserRequest
Expand All @@ -400,8 +390,10 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.PrincipalName = updateUserRequest.Principal
user.IsDisabled = updateUserRequest.IsDisabled

loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx)

if user.IsDisabled {
if loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx); user.ID == loggedInUser.ID {
if user.ID == loggedInUser.ID {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfDisable, request), response)
return
} else if userSessions, err := s.db.LookupActiveSessionsByUser(request.Context(), user); err != nil {
Expand All @@ -419,9 +411,6 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
if samlProviderID, err := serde.ParseInt32(updateUserRequest.SAMLProviderID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response)
return
} else if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
api.HandleDatabaseError(request, response, err)
return
Expand All @@ -431,10 +420,7 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = provider.SSOProviderID
}
} else if updateUserRequest.SSOProviderID.Valid {
if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
Expand All @@ -447,6 +433,17 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = null.NewInt32(0, false)
}

// Prevent a user from modifying their own roles/permissions
if user.ID == loggedInUser.ID {
if !slices.Equal(roles.IDs(), loggedInUser.Roles.IDs()) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfRoleChange, request), response)
return
} else if !user.SSOProviderID.Equal(loggedInUser.SSOProviderID) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfSSOChange, request), response)
return
}
}

if err := s.db.UpdateUser(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
Expand Down
8 changes: 7 additions & 1 deletion cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,17 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) {

func TestManagementResource_EnableUserSAML(t *testing.T) {
var (
adminUser = model.User{AuthSecret: defaultDigestAuthSecret(t, "currentPassword"), Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) {
Expand All @@ -181,6 +185,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -197,9 +202,9 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)
mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": badUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -218,6 +223,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand Down
11 changes: 10 additions & 1 deletion cmd/api/src/database/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ func (s *BloodhoundDB) UpdateUser(ctx context.Context, user model.User) error {
return err
}

// AuthSecret must be manually deleted and then set to nil again to prevent recreation
if user.AuthSecret == nil {
if err := tx.Unscoped().Model(&user).WithContext(ctx).Association("AuthSecret").Unscoped().Clear(); err != nil {
return err
}

user.AuthSecret = nil
}

mistahj67 marked this conversation as resolved.
Show resolved Hide resolved
result := tx.WithContext(ctx).Save(&user)
return CheckError(result)
})
Expand Down Expand Up @@ -431,7 +440,7 @@ func (s *BloodhoundDB) CreateAuthSecret(ctx context.Context, authSecret model.Au
func (s *BloodhoundDB) GetAuthSecret(ctx context.Context, id int32) (model.AuthSecret, error) {
var (
authSecret model.AuthSecret
result = s.db.WithContext(ctx).Find(&authSecret, id)
result = s.db.WithContext(ctx).First(&authSecret, id)
Copy link
Contributor Author

@wes-mil wes-mil Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This database function was not properly returning an error when there was no auth secret with ID

)

return authSecret, CheckError(result)
Expand Down
65 changes: 65 additions & 0 deletions cmd/api/src/database/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,71 @@ func TestDatabase_CreateGetDeleteUser(t *testing.T) {
}
}

func TestDatabase_UpdateUserAuth(t *testing.T) {
var (
ctx = context.Background()
dbInst, user = initAndCreateUser(t)
secret = model.AuthSecret{
UserID: user.ID,
Digest: "digest",
DigestMethod: "fake",
ExpiresAt: time.Now().Add(1 * time.Hour),
}
samlProvider = model.SAMLProvider{
Name: "provider",
DisplayName: "provider name",
IssuerURI: "https://idp.example.com/idp.xml",
SingleSignOnURI: "https://idp.example.com/sso",
}
)

if newSecret, err := dbInst.CreateAuthSecret(ctx, secret); err != nil {
t.Fatalf("Failed to create auth secret: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionCreateAuthSecret, "secret_user_id", newSecret.UserID.String()); err != nil {
t.Fatalf("Failed to validate CreateAuthSecret audit logs:\n%v", err)
} else {
if newSAMLProvider, err := dbInst.CreateSAMLIdentityProvider(ctx, samlProvider); err != nil {
t.Fatalf("Failed to create SAML provider: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionCreateSAMLIdentityProvider, "saml_name", newSAMLProvider.Name); err != nil {
t.Fatalf("Failed to validate CreateSAMLIdentityProvider audit logs:\n%v", err)
} else {
user, err = dbInst.GetUser(ctx, user.ID)
if err != nil {
t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err)
}

user.FirstName = null.StringFrom("friendly man")

if err := dbInst.UpdateUser(ctx, user); err != nil {
t.Fatalf("Failed to update user: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionUpdateUser, "principal_name", user.PrincipalName); err != nil {
t.Fatalf("Failed to validate UpdateUser audit logs:\n%v", err)
} else if updatedUser, err := dbInst.GetUser(ctx, user.ID); err != nil {
t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err)
} else if updatedUser.AuthSecret == nil {
t.Fatalf("Failed to find authsecret for user %s", user.PrincipalName)
} else if _, err := dbInst.GetAuthSecret(ctx, updatedUser.AuthSecret.ID); err != nil {
t.Fatalf("Failed to get authsecret by id %d", updatedUser.AuthSecret.ID)
}

user.AuthSecret = nil
user.SSOProviderID = newSAMLProvider.SSOProviderID

if err := dbInst.UpdateUser(ctx, user); err != nil {
t.Fatalf("Failed to update user: %v", err)
} else if err = test.VerifyAuditLogs(dbInst, model.AuditLogActionUpdateUser, "principal_name", user.PrincipalName); err != nil {
t.Fatalf("Failed to validate UpdateUser audit logs:\n%v", err)
} else if updatedUser, err := dbInst.GetUser(ctx, user.ID); err != nil {
t.Fatalf("Failed looking up user by principal %s: %v", user.PrincipalName, err)
} else if updatedUser.AuthSecret != nil {
t.Fatalf("Found authsecret for user %s but expected it to be removed", user.PrincipalName)
} else if _, err := dbInst.GetAuthSecret(ctx, newSecret.ID); err == nil {
t.Fatalf("Found authsecret for id %d but expected it to be removed", newSecret.ID)
}
}
}
}

func TestDatabase_CreateGetDeleteAuthToken(t *testing.T) {
var (
ctx = context.Background()
Expand Down
16 changes: 5 additions & 11 deletions cmd/api/src/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ require (
github.com/go-chi/chi/v5 v5.0.8
github.com/gobeam/stringy v0.0.6
github.com/gofrs/uuid v4.4.0+incompatible
github.com/golang-jwt/jwt/v4 v4.5.0
github.com/golang-jwt/jwt/v4 v4.5.1
github.com/golang/mock v1.6.0
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0
Expand All @@ -44,12 +44,11 @@ require (
github.com/unrolled/secure v1.13.0
go.uber.org/mock v0.2.0
golang.org/x/oauth2 v0.23.0
gorm.io/driver/postgres v1.3.8
gorm.io/gorm v1.23.8
gorm.io/driver/postgres v1.5.10
gorm.io/gorm v1.25.12
)

require (
github.com/Masterminds/semver/v3 v3.2.1 // indirect
github.com/beevik/etree v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1 // indirect
Expand All @@ -61,18 +60,12 @@ require (
github.com/go-pkgz/expirable-cache v1.0.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.14.3 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.3 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgtype v1.14.4 // indirect
github.com/jackc/pgx/v4 v4.18.3 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
Expand All @@ -82,6 +75,7 @@ require (
github.com/prometheus/procfs v0.11.0 // indirect
github.com/rivo/uniseg v0.4.4 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/time v0.3.0 // indirect
Expand Down
Loading
Loading