From b5e60d1cf42d0bc11939c51257f68cb9fd0146fb Mon Sep 17 00:00:00 2001 From: dasulin Date: Mon, 18 Nov 2024 00:49:06 +0200 Subject: [PATCH] Change logic in handler,db,events and align rhsso_test to new function headers --- internal/common/db.go | 2 +- internal/events/event.go | 16 +++++++----- pkg/auth/rhsso_authz_handler.go | 20 +++++++++++--- pkg/auth/rhsso_authz_handler_test.go | 39 ++++++++++++++++++---------- pkg/ocm/mock_authorization.go | 9 ++++--- 5 files changed, 58 insertions(+), 28 deletions(-) diff --git a/internal/common/db.go b/internal/common/db.go index cc1f49aa7f6..14e2022e390 100644 --- a/internal/common/db.go +++ b/internal/common/db.go @@ -433,7 +433,7 @@ func GetInfraEnvHostsFromDB(db *gorm.DB, infraEnvID strfmt.UUID) ([]*Host, error func GetInfraEnvsFromDBWhere(db *gorm.DB, where ...interface{}) ([]*InfraEnv, error) { var infraEnvs []*InfraEnv - err := db.Find(&infraEnvs, where...).Error + err := db.Joins("LEFT JOIN clusters on infra_envs.cluster_id = clusters.id").Find(&infraEnvs, where...).Error if err != nil { return nil, err } diff --git a/internal/events/event.go b/internal/events/event.go index 9762408ce13..1a16439a855 100644 --- a/internal/events/event.go +++ b/internal/events/event.go @@ -351,16 +351,19 @@ func (e Events) prepareEventsTable(ctx context.Context, tx *gorm.DB, clusterID * return tx } + //When using rhsso auth, required to have clusters fields and infra_envs fields + tx = tx.Joins("LEFT JOIN infra_envs ON infra_envs.id = events.infra_env_id"). + Joins("LEFT JOIN clusters ON events.cluster_id = clusters.id") + //for bound events that are searched with cluster id (whether on clusters, bound infra-env , //host bound to a cluster or registered to a bound infra-env) check the access permission //relative to the cluster ownership if clusterBoundEvents() { - tx = tx.Model(&common.Event{}).Select("events.*, clusters.user_name, clusters.org_id"). - Joins("INNER JOIN clusters ON clusters.id = events.cluster_id") + tx = tx.Model(&common.Event{}).Select("events.*, clusters.openshift_cluster_id, clusters.user_name, clusters.org_id") // if deleted hosts flag is true, we need to add 'deleted_at' to know whether events are related to a deleted host if swag.BoolValue(deletedHosts) { - tx = tx.Select("events.*, clusters.user_name, clusters.org_id, hosts.deleted_at"). + tx = tx.Select("events.*, clusters.openshift_cluster_id, clusters.user_name, clusters.org_id, hosts.deleted_at"). Joins("LEFT JOIN hosts ON hosts.id = events.host_id") } return tx @@ -369,15 +372,14 @@ func (e Events) prepareEventsTable(ctx context.Context, tx *gorm.DB, clusterID * //for unbound events that are searched with infra-env id (whether events on hosts or the //infra-env level itself) check the access permission relative to the infra-env ownership if nonBoundEvents() { - return tx.Model(&common.Event{}).Select("events.*, infra_envs.user_name, infra_envs.org_id"). - Joins("INNER JOIN infra_envs ON infra_envs.id = events.infra_env_id") + return tx.Model(&common.Event{}). + Select("events.*, infra_envs.user_name, infra_envs.org_id, clusters.openshift_cluster_id") } // Events must be linked to the infra_envs table and then to the hosts table // The hosts table does not hold an org_id, so permissions related fields must be supplied by the infra_env if hostOnlyEvents() { - return tx.Model(&common.Event{}).Select("events.*, infra_envs.user_name, infra_envs.org_id"). - Joins("INNER JOIN infra_envs ON infra_envs.id = events.infra_env_id"). + return tx.Model(&common.Event{}).Select("events.*, infra_envs.user_name, infra_envs.org_id, clusters.openshift_cluster_id"). Joins("INNER JOIN hosts ON hosts.id = events.host_id"). // This join is here to ensure that only events for a host that exists are fetched Where("hosts.deleted_at IS NULL") // Only interested in active hosts } diff --git a/pkg/auth/rhsso_authz_handler.go b/pkg/auth/rhsso_authz_handler.go index 7911cc65678..c6e78812600 100644 --- a/pkg/auth/rhsso_authz_handler.go +++ b/pkg/auth/rhsso_authz_handler.go @@ -80,11 +80,25 @@ func (a *AuthzHandler) OwnedBy(ctx context.Context, db *gorm.DB, resource Resour if err != nil { return nil, err } - if resource != ClusterResource { - return db.Where("cluster_id IN ?", allowedClusterID), nil + + query := "" + switch resource { + case InfraEnvResource: + query = "infra_envs.user_name = ? OR cluster_id IN (?) OR openshift_cluster_id in (?)" + case EventsResource: + query = "infra_envs.user_name = ? OR events.cluster_id IN (?) OR openshift_cluster_id in (?)" + default: + query = "user_name = ? OR id IN (?) OR openshift_cluster_id in (?)" } - return db.Where("id IN ? OR openshift_cluster_id IN ?", allowedClusterID, allowedClusterUuids), nil + // querys := map[Resource]string{ + // ClusterResource: "user_name = ? OR id IN (?) OR openshift_cluster_id in (?)", + // InfraEnvResource: "infra_envs.user_name = ? OR cluster_id IN (?) OR openshift_cluster_id in (?)", + // EventsResource: "infra_envs.user_name = ? OR events.cluster_id IN (?) OR openshift_cluster_id in (?)", + // } + + // return db.Where(querys[resource], ocm.UserNameFromContext(ctx), allowedClusterID, allowedClusterUuids), nil + return db.Where(query, ocm.UserNameFromContext(ctx), allowedClusterID, allowedClusterUuids), nil } if a.isTenancyEnabled() { return db.Where("org_id = ?", ocm.OrgIDFromContext(ctx)), nil diff --git a/pkg/auth/rhsso_authz_handler_test.go b/pkg/auth/rhsso_authz_handler_test.go index 8d88214c922..0c89a180581 100644 --- a/pkg/auth/rhsso_authz_handler_test.go +++ b/pkg/auth/rhsso_authz_handler_test.go @@ -157,7 +157,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Find(&records) Expect(results.RowsAffected, 4) }) It("admin user - non-empty query", func() { @@ -166,7 +167,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Where("name = ?", "A").Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Where("name = ?", "A").Find(&records) Expect(results.RowsAffected, 3) Expect(AllRecordsHasName(records, "A")).To(BeTrue()) }) @@ -176,7 +178,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedByUser(ctx, db, "user1").Find(&records) + results, _ := handler.OwnedByUser(ctx, db, "", "user1") + results.Find(&records) Expect(results.RowsAffected, 2) Expect(AllRecordsHasUserName(records, "user1")).To(BeTrue()) }) @@ -187,7 +190,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Find(&records) Expect(results.RowsAffected, 2) Expect(AllRecordsHasUserName(records, "user1")).To(BeTrue()) }) @@ -197,7 +201,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Where("name = ?", "A").Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Where("name = ?", "A").Find(&records) Expect(results.RowsAffected, 2) Expect(AllRecordsHasName(records, "A")).To(BeTrue()) Expect(AllRecordsHasUserName(records, "user1")).To(BeTrue()) @@ -208,7 +213,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedByUser(ctx, db, "user1").Find(&records) + results, _ := handler.OwnedByUser(ctx, db, "", "user1") + results.Find(&records) Expect(results.RowsAffected, 2) Expect(AllRecordsHasUserName(records, "user1")).To(BeTrue()) }) @@ -218,7 +224,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedByUser(ctx, db, "user2").Find(&records) + results, _ := handler.OwnedByUser(ctx, db, "", "user2") + results.Find(&records) Expect(results.RowsAffected, 0) }) }) @@ -233,7 +240,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Find(&records) Expect(results.RowsAffected, 4) }) It("admin user - non-empty query", func() { @@ -242,7 +250,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Where("name = ?", "A").Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Where("name = ?", "A").Find(&records) Expect(results.RowsAffected, 3) Expect(AllRecordsHasName(records, "A")).To(BeTrue()) }) @@ -252,7 +261,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Find(&records) Expect(results.RowsAffected, 2) Expect(AllRecordsHasOrgId(records, "org1")).To(BeTrue()) }) @@ -262,7 +272,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedBy(ctx, db).Where("name = ?", "A").Find(&records) + results, _ := handler.OwnedBy(ctx, db, "") + results.Find(&records).Where("name = ?", "A").Find(&records) Expect(results.RowsAffected, 1) Expect(AllRecordsHasName(records, "A")).To(BeTrue()) Expect(AllRecordsHasOrgId(records, "org1")).To(BeTrue()) @@ -274,7 +285,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedByUser(ctx, db, "user1").Find(&records) + results, _ := handler.OwnedByUser(ctx, db, "", "user1") + results.Find(&records) Expect(results.RowsAffected, 1) Expect(AllRecordsHasUserName(records, "user1")).To(BeTrue()) }) @@ -285,7 +297,8 @@ var _ = Describe("OwnedBy", func() { ctx = context.WithValue(ctx, restapi.AuthKey, payload) var records []common.Cluster - results := handler.OwnedByUser(ctx, db, "user2").Find(&records) + results, _ := handler.OwnedByUser(ctx, db, "", "user2") + results.Find(&records) Expect(results.RowsAffected, 0) }) }) diff --git a/pkg/ocm/mock_authorization.go b/pkg/ocm/mock_authorization.go index 31d022c6b73..5f95badd960 100644 --- a/pkg/ocm/mock_authorization.go +++ b/pkg/ocm/mock_authorization.go @@ -49,12 +49,13 @@ func (mr *MockOCMAuthorizationMockRecorder) AccessReview(ctx, username, action, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AccessReview", reflect.TypeOf((*MockOCMAuthorization)(nil).AccessReview), ctx, username, action, subscriptionId, resourceType) } -func (m *MockOCMAuthorization) ResourceReview(ctx context.Context, username, action, resourceType string) ([]string, error) { +func (m *MockOCMAuthorization) ResourceReview(ctx context.Context, username, action, resourceType string) ([]string,[]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AccessReview", ctx, username, action, resourceType) + ret := m.ctrl.Call(m, "ResourceReview", ctx, username, action, resourceType) ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret1, _ := ret[1].([]string) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // AccessReview indicates an expected call of AccessReview.