Skip to content

Commit

Permalink
chore: BED-5198 - merge with mainline before closing the stage branch
Browse files Browse the repository at this point in the history
  • Loading branch information
zinic committed Dec 20, 2024
2 parents 5d6f959 + 4f5d10e commit 4a15da4
Show file tree
Hide file tree
Showing 49 changed files with 1,086 additions and 653 deletions.
85 changes: 71 additions & 14 deletions cmd/api/src/analysis/ad/adcs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,63 +447,75 @@ func TestTrustedForNTAuth(t *testing.T) {
func TestEnrollOnBehalfOf(t *testing.T) {
testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema())
testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
harness.EnrollOnBehalfOfHarnessOne.Setup(testContext)
harness.EnrollOnBehalfOfHarness1.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate)
v1Templates := make([]*graph.Node, 0)
v2Templates := make([]*graph.Node, 0)

for _, template := range certTemplates {
if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil {
continue
} else if version == 1 {
v1Templates = append(v1Templates, template)
} else if version >= 2 {
continue
v2Templates = append(v2Templates, template)
}
}

require.Nil(t, err)

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates)
results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1)
require.Nil(t, err)

require.Len(t, results, 3)

require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate11.ID,
ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate11.ID,
ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})

require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate13.ID,
ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate13.ID,
ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})

require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})

return nil
})

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1)
require.Nil(t, err)

require.Len(t, results, 0)

return nil
})
})

testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
harness.EnrollOnBehalfOfHarnessTwo.Setup(testContext)
harness.EnrollOnBehalfOfHarness2.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate)
v1Templates := make([]*graph.Node, 0)
v2Templates := make([]*graph.Node, 0)

for _, template := range certTemplates {
if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil {
continue
} else if version == 1 {
continue
v1Templates = append(v1Templates, template)
} else if version >= 2 {
v2Templates = append(v2Templates, template)
}
Expand All @@ -512,15 +524,60 @@ func TestEnrollOnBehalfOf(t *testing.T) {
require.Nil(t, err)

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates)
results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2)
require.Nil(t, err)

require.Len(t, results, 0)
return nil
})

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2)
require.Nil(t, err)

require.Len(t, results, 1)
require.Contains(t, results, analysis.CreatePostRelationshipJob{
FromID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate21.ID,
ToID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate23.ID,
FromID: harness.EnrollOnBehalfOfHarness2.CertTemplate21.ID,
ToID: harness.EnrollOnBehalfOfHarness2.CertTemplate23.ID,
Kind: ad.EnrollOnBehalfOf,
})
return nil
})
})

testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
harness.EnrollOnBehalfOfHarness3.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - EnrollOnBehalfOf 3")

_, enterpriseCertAuthorities, certTemplates, domains, cache, err := FetchADCSPrereqs(db)
require.Nil(t, err)

if err := ad2.PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates, cache, operation); err != nil {
t.Logf("failed post processing for %s: %v", ad.EnrollOnBehalfOf.String(), err)
}
err = operation.Done()
require.Nil(t, err)

db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
if startNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria {
return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf)
})); err != nil {
t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err)
} else if endNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria {
return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf)
})); err != nil {
t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err)
} else {
require.Len(t, startNodes, 2)
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate11))
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))

require.Len(t, endNodes, 2)
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
}

return nil
})
Expand Down
3 changes: 3 additions & 0 deletions cmd/api/src/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ const (
ErrorResponsePayloadUnmarshalError = "error unmarshalling JSON payload"
ErrorResponseRequestTimeout = "request timed out"
ErrorResponseUserSelfDisable = "user attempted to disable themselves"
ErrorResponseUserSelfRoleChange = "user attempted to change own role"
ErrorResponseUserSelfSSOProviderChange = "user attempted to change own SSO Provider"
ErrorResponseAGTagWhiteSpace = "asset group tags must not contain whitespace"
ErrorResponseAGNameTagEmpty = "asset group name or tag must not be empty"
ErrorResponseAGDuplicateName = "asset group name must be unique"
ErrorResponseAGDuplicateTag = "asset group tag must be unique"
ErrorResponseUserDuplicatePrincipal = "principal name must be unique"
ErrorResponseDetailsUniqueViolation = "unique constraint was violated"
ErrorResponseDetailsNotImplemented = "All good things to those who wait. Not implemented."

Expand Down
45 changes: 25 additions & 20 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,24 +356,18 @@ func (s ManagementResource) CreateUser(response http.ResponseWriter, request *ht
}

if newUser, err := s.db.CreateUser(request.Context(), userTemplate); err != nil {
api.HandleDatabaseError(request, response, err)
if errors.Is(err, database.ErrDuplicateUserPrincipal) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, api.ErrorResponseUserDuplicatePrincipal, request), response)
} else {
api.HandleDatabaseError(request, response, err)
}
} else {
api.WriteBasicResponse(request.Context(), newUser, http.StatusOK, response)
}

}
}

func (s ManagementResource) ensureUserHasNoAuthSecret(ctx context.Context, user model.User) error {
if user.AuthSecret != nil {
if err := s.db.DeleteAuthSecret(ctx, *user.AuthSecret); err != nil {
return api.FormatDatabaseError(err)
}
}

return nil
}

func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *http.Request) {
var (
updateUserRequest v2.UpdateUserRequest
Expand All @@ -400,8 +394,10 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.PrincipalName = updateUserRequest.Principal
user.IsDisabled = updateUserRequest.IsDisabled

loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx)

if user.IsDisabled {
if loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx); user.ID == loggedInUser.ID {
if user.ID == loggedInUser.ID {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfDisable, request), response)
return
} else if userSessions, err := s.db.LookupActiveSessionsByUser(request.Context(), user); err != nil {
Expand All @@ -419,9 +415,6 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
if samlProviderID, err := serde.ParseInt32(updateUserRequest.SAMLProviderID); err != nil {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, fmt.Sprintf("SAML Provider ID must be a number: %v", err.Error()), request), response)
return
} else if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if provider, err := s.db.GetSAMLProvider(request.Context(), samlProviderID); err != nil {
api.HandleDatabaseError(request, response, err)
return
Expand All @@ -431,10 +424,7 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = provider.SSOProviderID
}
} else if updateUserRequest.SSOProviderID.Valid {
if err := s.ensureUserHasNoAuthSecret(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
if _, err := s.db.GetSSOProviderById(request.Context(), updateUserRequest.SSOProviderID.Int32); err != nil {
api.HandleDatabaseError(request, response, err)
return
} else {
Expand All @@ -447,8 +437,23 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
user.SSOProviderID = null.NewInt32(0, false)
}

// Prevent a user from modifying their own roles/permissions
if user.ID == loggedInUser.ID {
if !slices.Equal(roles.IDs(), loggedInUser.Roles.IDs()) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfRoleChange, request), response)
return
} else if !user.SSOProviderID.Equal(loggedInUser.SSOProviderID) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSelfSSOProviderChange, request), response)
return
}
}

if err := s.db.UpdateUser(request.Context(), user); err != nil {
api.HandleDatabaseError(request, response, err)
if errors.Is(err, database.ErrDuplicateUserPrincipal) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, api.ErrorResponseUserDuplicatePrincipal, request), response)
} else {
api.HandleDatabaseError(request, response, err)
}
} else {
response.WriteHeader(http.StatusOK)
}
Expand Down
66 changes: 65 additions & 1 deletion cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,17 @@ func TestManagementResource_PutUserAuthSecret(t *testing.T) {

func TestManagementResource_EnableUserSAML(t *testing.T) {
var (
adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Successfully update user with deprecated saml provider", func(t *testing.T) {
Expand All @@ -181,6 +185,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -197,9 +202,9 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().GetUser(gomock.Any(), badUserID).Return(model.User{AuthSecret: &model.AuthSecret{}}, nil)
mockDB.EXPECT().GetSAMLProvider(gomock.Any(), samlProviderID).Return(model.SAMLProvider{}, nil)
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)
mockDB.EXPECT().DeleteAuthSecret(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": badUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand All @@ -218,6 +223,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
mockDB.EXPECT().UpdateUser(gomock.Any(), gomock.Any()).Return(nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Expand Down Expand Up @@ -1511,6 +1517,64 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) {
require.Contains(t, response.Body.String(), api.ErrorResponseUserSelfDisable)
}

func TestManagementResource_UpdateUser_UserSelfModify(t *testing.T) {
var (
adminRole = model.Role{
Serial: model.Serial{
ID: 1,
},
}
goodRoles = []int32{1}
badRole = model.Role{
Serial: model.Serial{
ID: 2,
},
}
badRoles = []int32{2}
adminUser = model.User{AuthSecret: defaultDigestAuthSecret(t, "currentPassword"), Unique: model.Unique{ID: must.NewUUIDv4()}, Roles: model.Roles{adminRole}}
mockCtrl = gomock.NewController(t)
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

bhCtx := ctx.Get(context.WithValue(context.Background(), ctx.ValueKey, &ctx.Context{}))
bhCtx.AuthCtx.Owner = adminUser

defer mockCtrl.Finish()

t.Run("Prevent users from changing their own SSO provider", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{adminRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProviderID).Return(model.SSOProvider{}, nil)
test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: goodRoles,
SSOProviderID: null.Int32From(123),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("Prevent users from changing their own roles", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Any()).Return(model.Roles{badRole}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), adminUser.ID).Return(adminUser, nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": adminUser.ID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: badRoles,
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})
}

func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
Expand Down
Loading

0 comments on commit 4a15da4

Please sign in to comment.