From a80cee3d071c609429ad3731c1fa83c3ea3d0007 Mon Sep 17 00:00:00 2001 From: Kian Parvin Date: Mon, 10 Jun 2024 18:00:37 +0200 Subject: [PATCH] Test fixes --- .github/workflows/ci.yaml | 1 + internal/db/model.go | 13 ++++--------- internal/db/model_test.go | 4 ++-- internal/dbmodel/cloud_test.go | 8 ++++++++ internal/dbmodel/controller.go | 3 +++ internal/jimm/jimm.go | 2 +- internal/jimm/watcher.go | 1 + internal/jimmtest/cmp.go | 2 +- 8 files changed, 21 insertions(+), 13 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 868ad4c95..7bd734e10 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,6 +25,7 @@ jobs: build_test: name: Build and Test runs-on: ubuntu-22.04 + timeout-minutes: 45 steps: - uses: actions/checkout@v4 with: diff --git a/internal/db/model.go b/internal/db/model.go index 8d22977a5..adc726662 100644 --- a/internal/db/model.go +++ b/internal/db/model.go @@ -182,21 +182,16 @@ func preloadModel(prefix string, db *gorm.DB) *gorm.DB { return db } -// GetModelsByControllerID retrieves a list of models hosted on the specified controller. -func (d *Database) GetModelsByControllerID(ctx context.Context, controllerID uint) ([]dbmodel.Model, error) { - const op = errors.Op("db.GetModelsByControllerID") +// GetModelsByController retrieves a list of models hosted on the specified controller. +func (d *Database) GetModelsByController(ctx context.Context, ctl dbmodel.Controller) ([]dbmodel.Model, error) { + const op = errors.Op("db.GetModelsByController") if err := d.ready(); err != nil { return nil, errors.E(op, err) } var models []dbmodel.Model db := d.DB.WithContext(ctx) - err := db.Where("controller_id = ?", controllerID).Find(&models).Error - if err != nil { - err = dbError(err) - if errors.ErrorCode(err) == errors.CodeNotFound { - return nil, errors.E(op, err, "model not found") - } + if err := db.Model(ctl).Association("Models").Delete(&models); err != nil { return nil, errors.E(op, dbError(err)) } return models, nil diff --git a/internal/db/model_test.go b/internal/db/model_test.go index 41a44d81a..1f4848cca 100644 --- a/internal/db/model_test.go +++ b/internal/db/model_test.go @@ -639,7 +639,7 @@ func (s *dbSuite) TestGetModelsByUUID(c *qt.C) { c.Check(models[2].Controller.Name, qt.Not(qt.Equals), "") } -func (s *dbSuite) TestGetModelsByControllerID(c *qt.C) { +func (s *dbSuite) TestGetModelsByController(c *qt.C) { err := s.Database.Migrate(context.Background(), true) c.Assert(err, qt.Equals, nil) @@ -725,7 +725,7 @@ func (s *dbSuite) TestGetModelsByControllerID(c *qt.C) { for _, m := range models { c.Assert(s.Database.DB.Create(&m).Error, qt.IsNil) } - foundModels, err := s.Database.GetModelsByControllerID(context.Background(), controller.ID) + foundModels, err := s.Database.GetModelsByController(context.Background(), controller) foundModelNames := []string{} for _, m := range foundModels { foundModelNames = append(foundModelNames, m.Name) diff --git a/internal/dbmodel/cloud_test.go b/internal/dbmodel/cloud_test.go index f4788edad..3f0d174bf 100644 --- a/internal/dbmodel/cloud_test.go +++ b/internal/dbmodel/cloud_test.go @@ -354,12 +354,16 @@ func TestCloudRegionControllers(t *testing.T) { c.Check(crcps, qt.HasLen, 2) c.Check(crcps[0].Controller, qt.DeepEquals, dbmodel.Controller{ ID: ctl2.ID, + CreatedAt: ctl2.CreatedAt, + UpdatedAt: ctl2.UpdatedAt, Name: ctl2.Name, CloudName: ctl2.CloudName, CloudRegion: ctl2.CloudRegion, }) c.Check(crcps[1].Controller, qt.DeepEquals, dbmodel.Controller{ ID: ctl1.ID, + CreatedAt: ctl1.CreatedAt, + UpdatedAt: ctl1.UpdatedAt, Name: ctl1.Name, CloudName: ctl1.CloudName, CloudRegion: ctl1.CloudRegion, @@ -371,12 +375,16 @@ func TestCloudRegionControllers(t *testing.T) { c.Check(crcps, qt.HasLen, 2) c.Check(crcps[0].Controller, qt.DeepEquals, dbmodel.Controller{ ID: ctl1.ID, + CreatedAt: ctl1.CreatedAt, + UpdatedAt: ctl1.UpdatedAt, Name: ctl1.Name, CloudName: ctl1.CloudName, CloudRegion: ctl1.CloudRegion, }) c.Check(crcps[1].Controller, qt.DeepEquals, dbmodel.Controller{ ID: ctl2.ID, + CreatedAt: ctl2.CreatedAt, + UpdatedAt: ctl2.UpdatedAt, Name: ctl2.Name, CloudName: ctl2.CloudName, CloudRegion: ctl2.CloudRegion, diff --git a/internal/dbmodel/controller.go b/internal/dbmodel/controller.go index d6a6e3750..d88963ff4 100644 --- a/internal/dbmodel/controller.go +++ b/internal/dbmodel/controller.go @@ -82,6 +82,9 @@ type Controller struct { // controller. CloudRegions []CloudRegionControllerPriority + // Models contains all the models that are running on this controller. + Models []Model + // TODO(mhilton) Save controller statistics? } diff --git a/internal/jimm/jimm.go b/internal/jimm/jimm.go index 783fed7c6..211978dce 100644 --- a/internal/jimm/jimm.go +++ b/internal/jimm/jimm.go @@ -527,7 +527,7 @@ func (j *JIMM) RemoveController(ctx context.Context, user *openfga.User, control return errors.E(errors.CodeStillAlive, "controller is still alive") } - models, err := db.GetModelsByControllerID(ctx, c.ID) + models, err := db.GetModelsByController(ctx, c) if err != nil { return err } diff --git a/internal/jimm/watcher.go b/internal/jimm/watcher.go index be6f6d403..3ba56ae2b 100644 --- a/internal/jimm/watcher.go +++ b/internal/jimm/watcher.go @@ -309,6 +309,7 @@ func (w *Watcher) watchController(ctx context.Context, ctl *dbmodel.Controller) continue } if v.changed { + v.changed = false // Update changed model. err := w.Database.Transaction(func(tx *db.Database) error { m := dbmodel.Model{ diff --git a/internal/jimmtest/cmp.go b/internal/jimmtest/cmp.go index 404bef2fb..978b695ae 100644 --- a/internal/jimmtest/cmp.go +++ b/internal/jimmtest/cmp.go @@ -45,7 +45,7 @@ var DBObjectEquals = qt.CmpEquals( cmpopts.IgnoreFields(dbmodel.CloudCredential{}, "CloudName", "OwnerIdentityName"), cmpopts.IgnoreFields(dbmodel.CloudRegion{}, "CloudName"), cmpopts.IgnoreFields(dbmodel.CloudRegionControllerPriority{}, "CloudRegionID", "ControllerID"), - cmpopts.IgnoreFields(dbmodel.Controller{}, "ID"), + cmpopts.IgnoreFields(dbmodel.Controller{}, "ID", "UpdateAt", "CreatedAt"), cmpopts.IgnoreFields(dbmodel.Model{}, "ID", "CreatedAt", "UpdatedAt", "OwnerIdentityName", "ControllerID", "CloudRegionID", "CloudCredentialID"), )