Skip to content

Commit

Permalink
Test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kian99 committed Jun 10, 2024
1 parent c6e0a2d commit a80cee3
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 13 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
build_test:
name: Build and Test
runs-on: ubuntu-22.04
timeout-minutes: 45
steps:
- uses: actions/checkout@v4
with:
Expand Down
13 changes: 4 additions & 9 deletions internal/db/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,16 @@ func preloadModel(prefix string, db *gorm.DB) *gorm.DB {
return db
}

// GetModelsByControllerID retrieves a list of models hosted on the specified controller.
func (d *Database) GetModelsByControllerID(ctx context.Context, controllerID uint) ([]dbmodel.Model, error) {
const op = errors.Op("db.GetModelsByControllerID")
// GetModelsByController retrieves a list of models hosted on the specified controller.
func (d *Database) GetModelsByController(ctx context.Context, ctl dbmodel.Controller) ([]dbmodel.Model, error) {
const op = errors.Op("db.GetModelsByController")

if err := d.ready(); err != nil {
return nil, errors.E(op, err)
}
var models []dbmodel.Model
db := d.DB.WithContext(ctx)
err := db.Where("controller_id = ?", controllerID).Find(&models).Error
if err != nil {
err = dbError(err)
if errors.ErrorCode(err) == errors.CodeNotFound {
return nil, errors.E(op, err, "model not found")
}
if err := db.Model(ctl).Association("Models").Delete(&models); err != nil {
return nil, errors.E(op, dbError(err))
}
return models, nil
Expand Down
4 changes: 2 additions & 2 deletions internal/db/model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ func (s *dbSuite) TestGetModelsByUUID(c *qt.C) {
c.Check(models[2].Controller.Name, qt.Not(qt.Equals), "")
}

func (s *dbSuite) TestGetModelsByControllerID(c *qt.C) {
func (s *dbSuite) TestGetModelsByController(c *qt.C) {
err := s.Database.Migrate(context.Background(), true)
c.Assert(err, qt.Equals, nil)

Expand Down Expand Up @@ -725,7 +725,7 @@ func (s *dbSuite) TestGetModelsByControllerID(c *qt.C) {
for _, m := range models {
c.Assert(s.Database.DB.Create(&m).Error, qt.IsNil)
}
foundModels, err := s.Database.GetModelsByControllerID(context.Background(), controller.ID)
foundModels, err := s.Database.GetModelsByController(context.Background(), controller)
foundModelNames := []string{}
for _, m := range foundModels {
foundModelNames = append(foundModelNames, m.Name)
Expand Down
8 changes: 8 additions & 0 deletions internal/dbmodel/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,16 @@ func TestCloudRegionControllers(t *testing.T) {
c.Check(crcps, qt.HasLen, 2)
c.Check(crcps[0].Controller, qt.DeepEquals, dbmodel.Controller{
ID: ctl2.ID,
CreatedAt: ctl2.CreatedAt,
UpdatedAt: ctl2.UpdatedAt,
Name: ctl2.Name,
CloudName: ctl2.CloudName,
CloudRegion: ctl2.CloudRegion,
})
c.Check(crcps[1].Controller, qt.DeepEquals, dbmodel.Controller{
ID: ctl1.ID,
CreatedAt: ctl1.CreatedAt,
UpdatedAt: ctl1.UpdatedAt,
Name: ctl1.Name,
CloudName: ctl1.CloudName,
CloudRegion: ctl1.CloudRegion,
Expand All @@ -371,12 +375,16 @@ func TestCloudRegionControllers(t *testing.T) {
c.Check(crcps, qt.HasLen, 2)
c.Check(crcps[0].Controller, qt.DeepEquals, dbmodel.Controller{
ID: ctl1.ID,
CreatedAt: ctl1.CreatedAt,
UpdatedAt: ctl1.UpdatedAt,
Name: ctl1.Name,
CloudName: ctl1.CloudName,
CloudRegion: ctl1.CloudRegion,
})
c.Check(crcps[1].Controller, qt.DeepEquals, dbmodel.Controller{
ID: ctl2.ID,
CreatedAt: ctl2.CreatedAt,
UpdatedAt: ctl2.UpdatedAt,
Name: ctl2.Name,
CloudName: ctl2.CloudName,
CloudRegion: ctl2.CloudRegion,
Expand Down
3 changes: 3 additions & 0 deletions internal/dbmodel/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ type Controller struct {
// controller.
CloudRegions []CloudRegionControllerPriority

// Models contains all the models that are running on this controller.
Models []Model

// TODO(mhilton) Save controller statistics?
}

Expand Down
2 changes: 1 addition & 1 deletion internal/jimm/jimm.go
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ func (j *JIMM) RemoveController(ctx context.Context, user *openfga.User, control
return errors.E(errors.CodeStillAlive, "controller is still alive")
}

models, err := db.GetModelsByControllerID(ctx, c.ID)
models, err := db.GetModelsByController(ctx, c)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions internal/jimm/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ func (w *Watcher) watchController(ctx context.Context, ctl *dbmodel.Controller)
continue
}
if v.changed {
v.changed = false
// Update changed model.
err := w.Database.Transaction(func(tx *db.Database) error {
m := dbmodel.Model{
Expand Down
2 changes: 1 addition & 1 deletion internal/jimmtest/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var DBObjectEquals = qt.CmpEquals(
cmpopts.IgnoreFields(dbmodel.CloudCredential{}, "CloudName", "OwnerIdentityName"),
cmpopts.IgnoreFields(dbmodel.CloudRegion{}, "CloudName"),
cmpopts.IgnoreFields(dbmodel.CloudRegionControllerPriority{}, "CloudRegionID", "ControllerID"),
cmpopts.IgnoreFields(dbmodel.Controller{}, "ID"),
cmpopts.IgnoreFields(dbmodel.Controller{}, "ID", "UpdateAt", "CreatedAt"),
cmpopts.IgnoreFields(dbmodel.Model{}, "ID", "CreatedAt", "UpdatedAt", "OwnerIdentityName", "ControllerID", "CloudRegionID", "CloudCredentialID"),
)

Expand Down

0 comments on commit a80cee3

Please sign in to comment.