diff --git a/internal/db/group.go b/internal/db/group.go index 671145d56..dacc1f730 100644 --- a/internal/db/group.go +++ b/internal/db/group.go @@ -108,8 +108,8 @@ func (d *Database) ListGroups(ctx context.Context, limit, offset int, match stri return groups, nil } -// UpdateGroup updates the group identified by its ID. -func (d *Database) UpdateGroup(ctx context.Context, group *dbmodel.GroupEntry) (err error) { +// UpdateGroupName updates the group name identified by its ID or UUID. +func (d *Database) UpdateGroupName(ctx context.Context, group *dbmodel.GroupEntry) (err error) { const op = errors.Op("db.UpdateGroup") if group.ID == 0 { diff --git a/internal/db/group_test.go b/internal/db/group_test.go index 1e852b837..0a596912b 100644 --- a/internal/db/group_test.go +++ b/internal/db/group_test.go @@ -119,8 +119,8 @@ func (s *dbSuite) TestGetGroup(c *qt.C) { c.Assert(group.UUID, qt.Equals, uuid2) } -func (s *dbSuite) TestUpdateGroup(c *qt.C) { - err := s.Database.UpdateGroup(context.Background(), &dbmodel.GroupEntry{Name: "test-group"}) +func (s *dbSuite) TestUpdateGroupName(c *qt.C) { + err := s.Database.UpdateGroupName(context.Background(), &dbmodel.GroupEntry{Name: "test-group"}) c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) err = s.Database.Migrate(context.Background(), false) @@ -130,7 +130,7 @@ func (s *dbSuite) TestUpdateGroup(c *qt.C) { Name: "test-group", } - err = s.Database.UpdateGroup(context.Background(), ge) + err = s.Database.UpdateGroupName(context.Background(), ge) c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) _, err = s.Database.AddGroup(context.Background(), "test-group") @@ -143,7 +143,7 @@ func (s *dbSuite) TestUpdateGroup(c *qt.C) { c.Assert(err, qt.IsNil) ge1.Name = "renamed-group" - err = s.Database.UpdateGroup(context.Background(), ge1) + err = s.Database.UpdateGroupName(context.Background(), ge1) c.Check(err, qt.IsNil) ge2 := &dbmodel.GroupEntry{ @@ -184,7 +184,7 @@ func (s *dbSuite) TestRemoveGroup(c *qt.C) { c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) } -func (s *dbSuite) TestForEachGroup(c *qt.C) { +func (s *dbSuite) TestListGroups(c *qt.C) { err := s.Database.Migrate(context.Background(), false) c.Assert(err, qt.IsNil) diff --git a/internal/db/role.go b/internal/db/role.go new file mode 100644 index 000000000..6fe038766 --- /dev/null +++ b/internal/db/role.go @@ -0,0 +1,156 @@ +// Copyright 2024 Canonical. + +package db + +import ( + "context" + + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/servermon" +) + +// AddRole adds a new role. +func (d *Database) AddRole(ctx context.Context, name string) (re *dbmodel.RoleEntry, err error) { + const op = errors.Op("db.AddRole") + if err := d.ready(); err != nil { + return nil, errors.E(op, err) + } + + durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op)) + defer durationObserver() + defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op)) + + re = &dbmodel.RoleEntry{ + Name: name, + UUID: newUUID(), + } + + if err := d.DB.WithContext(ctx).Create(re).Error; err != nil { + return nil, errors.E(op, dbError(err)) + } + return re, nil +} + +// GetRole populates the provided *dbmodel.RoleEntry based on name or UUID. +func (d *Database) GetRole(ctx context.Context, role *dbmodel.RoleEntry) (err error) { + const op = errors.Op("db.GetRole") + if err := d.ready(); err != nil { + return errors.E(op, err) + } + + durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op)) + defer durationObserver() + defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op)) + + if role.UUID == "" && role.Name == "" { + return errors.E(op, "must specify uuid or name") + } + + db := d.DB.WithContext(ctx) + if role.ID != 0 { + db = db.Where("id = ?", role.ID) + } + if role.UUID != "" { + db = db.Where("uuid = ?", role.UUID) + } + if role.Name != "" { + db = db.Where("name = ?", role.Name) + } + if err := db.First(&role).Error; err != nil { + return errors.E(op, dbError(err)) + } + return nil +} + +// UpdateRoleName updates the name of a role identified by UUID. +func (d *Database) UpdateRoleName(ctx context.Context, uuid, name string) (err error) { + const op = errors.Op("db.UpdateRole") + + if uuid == "" { + return errors.E(op, "uuid must be specified") + } + + if err := d.ready(); err != nil { + return errors.E(op, err) + } + + durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op)) + defer durationObserver() + defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op)) + + model := d.DB.WithContext(ctx).Model(&dbmodel.RoleEntry{}) + model.Where("uuid = ?", uuid) + if model.Update("name", name).RowsAffected == 0 { + return errors.E(op, errors.CodeNotFound, "role not found") + } + + return nil +} + +// RemoveRole removes the role identified by its ID or UUID. +func (d *Database) RemoveRole(ctx context.Context, role *dbmodel.RoleEntry) (err error) { + const op = errors.Op("db.RemoveRole") + + if role.ID == 0 && role.UUID == "" { + return errors.E("neither role UUID or ID specified", errors.CodeNotFound) + } + + if err := d.ready(); err != nil { + return errors.E(op, err) + } + + durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op)) + defer durationObserver() + defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op)) + + if err := d.DB.WithContext(ctx).Delete(role).Error; err != nil { + return errors.E(op, dbError(err)) + } + return nil +} + +// ListRoles returns a paginated list of Roles defined by limit and offset. +// match is used to fuzzy find based on entries' name or uuid using the LIKE operator (ex. LIKE %%). +func (d *Database) ListRoles(ctx context.Context, limit, offset int, match string) (_ []dbmodel.RoleEntry, err error) { + const op = errors.Op("db.ListRoles") + if err := d.ready(); err != nil { + return nil, errors.E(op, err) + } + + durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op)) + defer durationObserver() + defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op)) + + db := d.DB.WithContext(ctx) + if match != "" { + db = db.Where("name LIKE ? OR uuid LIKE ?", "%"+match+"%", "%"+match+"%") + } + db = db.Order("name asc") + db = db.Limit(limit) + db = db.Offset(offset) + var Roles []dbmodel.RoleEntry + if err := db.Find(&Roles).Error; err != nil { + return nil, errors.E(op, dbError(err)) + } + return Roles, nil +} + +// CountRoles returns a count of the number of Roles that exist. +func (d *Database) CountRoles(ctx context.Context) (count int, err error) { + const op = errors.Op("db.CountRoles") + if err := d.ready(); err != nil { + return 0, errors.E(op, err) + } + durationObserver := servermon.DurationObserver(servermon.DBQueryDurationHistogram, string(op)) + defer durationObserver() + defer servermon.ErrorCounter(servermon.DBQueryErrorCount, &err, string(op)) + + var c int64 + var g dbmodel.RoleEntry + if err := d.DB.WithContext(ctx).Model(g).Count(&c).Error; err != nil { + return 0, errors.E(op, dbError(err)) + } + count = int(c) + return count, nil +} diff --git a/internal/db/role_test.go b/internal/db/role_test.go new file mode 100644 index 000000000..fc56db4a8 --- /dev/null +++ b/internal/db/role_test.go @@ -0,0 +1,203 @@ +// Copyright 2024 Canonical. + +package db_test + +import ( + "context" + "fmt" + + qt "github.com/frankban/quicktest" + "github.com/google/uuid" + + "github.com/canonical/jimm/v3/internal/db" + "github.com/canonical/jimm/v3/internal/dbmodel" + "github.com/canonical/jimm/v3/internal/errors" + "github.com/canonical/jimm/v3/internal/testutils/jimmtest" +) + +func (s *dbSuite) TestAddRole(c *qt.C) { + ctx := context.Background() + + uuid := uuid.NewString() + c.Patch(db.NewUUID, func() string { + return uuid + }) + + _, err := s.Database.AddRole(ctx, "test-role") + c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeUpgradeInProgress) + + err = s.Database.Migrate(context.Background(), false) + c.Assert(err, qt.IsNil) + + roleEntry, err := s.Database.AddRole(ctx, "test-role") + c.Assert(err, qt.IsNil) + c.Assert(roleEntry.UUID, qt.Not(qt.Equals), "") + + _, err = s.Database.AddRole(ctx, "test-role") + c.Assert(errors.ErrorCode(err), qt.Equals, errors.CodeAlreadyExists) + + re := dbmodel.RoleEntry{ + Name: "test-role", + } + tx := s.Database.DB.First(&re) + c.Assert(tx.Error, qt.IsNil) + c.Assert(re.ID, qt.Equals, uint(1)) + c.Assert(re.Name, qt.Equals, "test-role") + c.Assert(re.UUID, qt.Equals, uuid) +} + +func (s *dbSuite) TestGetRole(c *qt.C) { + uuid1 := uuid.NewString() + c.Patch(db.NewUUID, func() string { + return uuid1 + }) + + err := s.Database.GetRole(context.Background(), &dbmodel.RoleEntry{}) + c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeUpgradeInProgress) + + err = s.Database.Migrate(context.Background(), false) + c.Assert(err, qt.IsNil) + + role := &dbmodel.RoleEntry{} + err = s.Database.GetRole(context.Background(), role) + c.Check(err, qt.ErrorMatches, "must specify uuid or name") + + re1, err := s.Database.AddRole(context.TODO(), "test-role") + c.Assert(err, qt.IsNil) + c.Assert(re1.UUID, qt.Equals, uuid1) + + // Get by UUID + re2 := &dbmodel.RoleEntry{ + UUID: uuid1, + } + err = s.Database.GetRole(context.Background(), re2) + c.Assert(err, qt.IsNil) + c.Assert(re1, jimmtest.DBObjectEquals, re2) + + // Get by name + re3 := &dbmodel.RoleEntry{ + Name: "test-role", + } + err = s.Database.GetRole(context.Background(), re3) + c.Assert(err, qt.IsNil) + c.Assert(re1, jimmtest.DBObjectEquals, re3) +} + +func (s *dbSuite) TestUpdateRoleName(c *qt.C) { + err := s.Database.Migrate(context.Background(), false) + c.Assert(err, qt.IsNil) + + err = s.Database.UpdateRoleName(context.Background(), "blah", "blah") + c.Check(err, qt.ErrorMatches, "role not found") + c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) + + err = s.Database.UpdateRoleName(context.Background(), "", "") + c.Check(err, qt.ErrorMatches, "uuid must be specified") + + _, err = s.Database.AddRole(context.Background(), "test-role") + c.Assert(err, qt.IsNil) + + re1 := &dbmodel.RoleEntry{ + Name: "test-role", + } + err = s.Database.GetRole(context.Background(), re1) + c.Assert(err, qt.IsNil) + + err = s.Database.UpdateRoleName(context.Background(), re1.UUID, "renamed-role") + c.Check(err, qt.IsNil) + + re2 := &dbmodel.RoleEntry{ + UUID: re1.UUID, + } + err = s.Database.GetRole(context.Background(), re2) + c.Check(err, qt.IsNil) + c.Assert(re2.Name, qt.Equals, "renamed-role") +} + +func (s *dbSuite) TestRemoveRole(c *qt.C) { + err := s.Database.RemoveRole(context.Background(), &dbmodel.RoleEntry{Name: "test-role"}) + c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) + + err = s.Database.Migrate(context.Background(), false) + c.Assert(err, qt.IsNil) + + re := &dbmodel.RoleEntry{ + Name: "test-role", + } + err = s.Database.RemoveRole(context.Background(), re) + c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) + + roleEntry, err := s.Database.AddRole(context.Background(), re.Name) + c.Assert(err, qt.IsNil) + + ge1 := &dbmodel.RoleEntry{ + Name: "test-role", + } + err = s.Database.GetRole(context.Background(), ge1) + c.Assert(err, qt.IsNil) + c.Assert(roleEntry.UUID, qt.Equals, ge1.UUID) + + err = s.Database.RemoveRole(context.Background(), ge1) + c.Check(err, qt.IsNil) + + err = s.Database.GetRole(context.Background(), ge1) + c.Check(errors.ErrorCode(err), qt.Equals, errors.CodeNotFound) +} + +func (s *dbSuite) TestListRole(c *qt.C) { + err := s.Database.Migrate(context.Background(), false) + c.Assert(err, qt.IsNil) + + addNRoles := 10 + for i := range addNRoles { + _, err := s.Database.AddRole(context.Background(), fmt.Sprintf("test-role-%d", i)) + c.Assert(err, qt.IsNil) + } + ctx := context.Background() + firstRoles, err := s.Database.ListRoles(ctx, 5, 0, "") + c.Assert(err, qt.IsNil) + for i := 0; i < 5; i++ { + c.Assert(firstRoles[i].Name, qt.Equals, fmt.Sprintf("test-role-%d", i)) + } + secondRoles, err := s.Database.ListRoles(ctx, 5, 5, "") + c.Assert(err, qt.IsNil) + for i := 0; i < 5; i++ { + c.Assert(secondRoles[i].Name, qt.Equals, fmt.Sprintf("test-role-%d", i+5)) + } + + matchedRoles, err := s.Database.ListRoles(ctx, 5, 0, "role-1") + c.Assert(err, qt.IsNil) + c.Assert(matchedRoles, qt.HasLen, 1) + c.Assert(matchedRoles[0].Name, qt.Equals, "test-role-1") + + matchedRoles, err = s.Database.ListRoles(ctx, 5, 0, "%not-existing%") + c.Assert(err, qt.IsNil) + c.Assert(matchedRoles, qt.HasLen, 0) + + tg, err := s.Database.AddRole(context.Background(), "\\%test-role") + c.Assert(err, qt.IsNil) + + matchedRoles, err = s.Database.ListRoles(ctx, 5, 0, "\\%t") + c.Assert(err, qt.IsNil) + c.Assert(matchedRoles, qt.HasLen, 1) + c.Assert(matchedRoles[0].UUID, qt.Equals, tg.UUID) + + matchedRoles, err = s.Database.ListRoles(ctx, 5, 0, tg.UUID) + c.Assert(err, qt.IsNil) + c.Assert(matchedRoles, qt.HasLen, 1) + c.Assert(matchedRoles[0].UUID, qt.Equals, tg.UUID) +} + +func (s *dbSuite) TestCountRoles(c *qt.C) { + err := s.Database.Migrate(context.Background(), false) + c.Assert(err, qt.IsNil) + + addNRoles := 10 + for i := range addNRoles { + _, err := s.Database.AddRole(context.Background(), fmt.Sprintf("test-role-%d", i)) + c.Assert(err, qt.IsNil) + } + count, err := s.Database.CountRoles(context.Background()) + c.Assert(err, qt.IsNil) + c.Assert(count, qt.Equals, addNRoles) +} diff --git a/internal/jimm/access.go b/internal/jimm/access.go index 11fe0a9a0..2b7d6f02f 100644 --- a/internal/jimm/access.go +++ b/internal/jimm/access.go @@ -757,7 +757,7 @@ func (j *JIMM) RenameGroup(ctx context.Context, user *openfga.User, oldName, new } group.Name = newName - if err := j.Database.UpdateGroup(ctx, group); err != nil { + if err := j.Database.UpdateGroupName(ctx, group); err != nil { return errors.E(op, err) } return nil diff --git a/internal/testutils/jimmtest/cmp.go b/internal/testutils/jimmtest/cmp.go index a100e18df..2c8f11f96 100644 --- a/internal/testutils/jimmtest/cmp.go +++ b/internal/testutils/jimmtest/cmp.go @@ -48,6 +48,7 @@ var DBObjectEquals = qt.CmpEquals( cmpopts.IgnoreFields(dbmodel.Controller{}, "ID", "UpdatedAt", "CreatedAt"), cmpopts.IgnoreFields(dbmodel.Model{}, "ID", "CreatedAt", "UpdatedAt", "OwnerIdentityName", "ControllerID", "CloudRegionID", "CloudCredentialID"), cmpopts.IgnoreFields(dbmodel.ApplicationOffer{}, "ID", "CreatedAt", "UpdatedAt", "ModelID"), + cmpopts.IgnoreFields(dbmodel.RoleEntry{}, "CreatedAt", "UpdatedAt"), ) // CmpEquals uses cmp.Diff (see http://godoc.org/github.com/google/go-cmp/cmp#Diff)