From 98cc7b0fe23be3f2aa4c2494fab631ef89c144c0 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 12 Feb 2024 09:09:28 -0800 Subject: [PATCH] feat: fixup migration visiblity for ordering (#408) --- cmd/api/src/database/db.go | 5 +++++ cmd/api/src/database/migration/stepwise.go | 18 +++++++++++++++ cmd/api/src/database/mocks/db.go | 15 +++++++++++++ cmd/api/src/services/entrypoint.go | 2 ++ packages/go/dawgs/drivers/neo4j/driver.go | 10 +++++++++ packages/go/dawgs/drivers/pg/driver.go | 26 ++++++++++++---------- packages/go/dawgs/drivers/pg/manager.go | 14 ++++++++++++ packages/go/dawgs/graph/graph.go | 3 +++ packages/go/dawgs/graph/mocks/graph.go | 14 ++++++++++++ packages/go/dawgs/graph/switch.go | 7 ++++++ 10 files changed, 102 insertions(+), 12 deletions(-) diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 6f98839499..08636ced68 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -82,6 +82,7 @@ type Database interface { RawFirst(value any) error Wipe() error Migrate() error + RequiresMigration() (bool, error) CreateAuditLog(auditLog model.AuditLog) error AppendAuditLog(ctx context.Context, entry model.AuditEntry) error ListAuditLogs(before, after time.Time, offset, limit int, order string, filter model.SQLFilter) (model.AuditLogs, int, error) @@ -226,6 +227,10 @@ func (s *BloodhoundDB) Wipe() error { }) } +func (s *BloodhoundDB) RequiresMigration() (bool, error) { + return migration.NewMigrator(s.db).RequiresMigration() +} + func (s *BloodhoundDB) Migrate() error { // Run the migrator if err := migration.NewMigrator(s.db).Migrate(); err != nil { diff --git a/cmd/api/src/database/migration/stepwise.go b/cmd/api/src/database/migration/stepwise.go index 549b380712..ac24b56067 100644 --- a/cmd/api/src/database/migration/stepwise.go +++ b/cmd/api/src/database/migration/stepwise.go @@ -127,6 +127,24 @@ ALTER TABLE ONLY migrations ADD CONSTRAINT migrations_pkey PRIMARY KEY (id);` return nil } +func (s *Migrator) RequiresMigration() (bool, error) { + // check if migration table exists to determine type of manifest to generate + if hasTable, err := s.HasMigrationTable(); err != nil { + return false, fmt.Errorf("failed to check if migration table exists: %w", err) + } else if !hasTable { + // no migration table, assume this is new installation and requires migration + return true, nil + } + + if lastMigration, err := s.LatestMigration(); err != nil { + return false, fmt.Errorf("could not get latest migration: %w", err) + } else if manifest, err := s.GenerateManifestAfterVersion(lastMigration.Version()); err != nil { + return false, fmt.Errorf("failed to generate migration manifest from previous version: %w", err) + } else { + return len(manifest.VersionTable) > 0, nil + } +} + // executeStepwiseMigrations will run all necessary migrations for a deployment. // It begins by checking if migration schema exists. If it does not, we assume the // deployment is a new installation, otherwise we assume it may have migration updates. diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 86cfc70612..41df83f043 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -1304,6 +1304,21 @@ func (mr *MockDatabaseMockRecorder) RawFirst(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RawFirst", reflect.TypeOf((*MockDatabase)(nil).RawFirst), arg0) } +// RequiresMigration mocks base method. +func (m *MockDatabase) RequiresMigration() (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RequiresMigration") + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RequiresMigration indicates an expected call of RequiresMigration. +func (mr *MockDatabaseMockRecorder) RequiresMigration() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequiresMigration", reflect.TypeOf((*MockDatabase)(nil).RequiresMigration)) +} + // SavedQueryBelongsToUser mocks base method. func (m *MockDatabase) SavedQueryBelongsToUser(arg0 uuid.UUID, arg1 int) (bool, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index 57e7b22445..2a8712f6bd 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -74,6 +74,8 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots } else if err := bootstrap.MigrateGraph(ctx, connections.Graph, schema.DefaultGraphSchema()); err != nil { return nil, fmt.Errorf("graph migration error: %w", err) } + } else if err := connections.Graph.SetDefaultGraph(ctx, schema.DefaultGraph()); err != nil { + return nil, fmt.Errorf("no default graph found but migrations are disabled per configuration: %w", err) } else { log.Infof("Database migrations are disabled per configuration") } diff --git a/packages/go/dawgs/drivers/neo4j/driver.go b/packages/go/dawgs/drivers/neo4j/driver.go index 58a811b77e..f01a2e1d35 100644 --- a/packages/go/dawgs/drivers/neo4j/driver.go +++ b/packages/go/dawgs/drivers/neo4j/driver.go @@ -135,6 +135,16 @@ func (s *driver) AssertSchema(ctx context.Context, schema graph.Schema) error { return assertSchema(ctx, s, schema) } +func (s *driver) SetDefaultGraph(ctx context.Context, schema graph.Graph) error { + // Note: Neo4j does not support isolated physical graph namespaces. Namespacing can be emulated with Kinds but will + // not be supported for this driver since the fallback behavior is no different from storing all graph data in the + // same namespace. + // + // This is different for the PostgreSQL driver, specifically, since the driver in question supports on-disk + // isolation of graph namespaces. + return nil +} + func (s *driver) Run(ctx context.Context, query string, parameters map[string]any) error { return s.WriteTransaction(ctx, func(tx graph.Transaction) error { result := tx.Raw(query, parameters) diff --git a/packages/go/dawgs/drivers/pg/driver.go b/packages/go/dawgs/drivers/pg/driver.go index 7dd7d86a3f..f68babca3f 100644 --- a/packages/go/dawgs/drivers/pg/driver.go +++ b/packages/go/dawgs/drivers/pg/driver.go @@ -58,16 +58,16 @@ type Driver struct { batchWriteSize int } -func (s *Driver) KindMapper() KindMapper { - return s.schemaManager -} - func (s *Driver) SetDefaultGraph(ctx context.Context, graphSchema graph.Graph) error { - return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - return s.schemaManager.AssertDefaultGraph(tx, graphSchema) + return s.ReadTransaction(ctx, func(tx graph.Transaction) error { + return s.schemaManager.SetDefaultGraph(tx, graphSchema) }) } +func (s *Driver) KindMapper() KindMapper { + return s.schemaManager +} + func (s *Driver) SetBatchWriteSize(size int) { s.batchWriteSize = size } @@ -177,7 +177,13 @@ func (s *Driver) FetchSchema(ctx context.Context) (graph.Schema, error) { func (s *Driver) AssertSchema(ctx context.Context, schema graph.Schema) error { if err := s.WriteTransaction(ctx, func(tx graph.Transaction) error { - return s.schemaManager.AssertSchema(tx, schema) + if err := s.schemaManager.AssertSchema(tx, schema); err != nil { + return err + } else if schema.DefaultGraph.Name != "" { + return s.schemaManager.AssertDefaultGraph(tx, schema.DefaultGraph) + } + + return nil }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)); err != nil { return err } else { @@ -185,11 +191,7 @@ func (s *Driver) AssertSchema(ctx context.Context, schema graph.Schema) error { s.pool.Reset() } - if schema.DefaultGraph.Name == "" { - return nil - } - - return s.SetDefaultGraph(ctx, schema.DefaultGraph) + return nil } func (s *Driver) Run(ctx context.Context, query string, parameters map[string]any) error { diff --git a/packages/go/dawgs/drivers/pg/manager.go b/packages/go/dawgs/drivers/pg/manager.go index 0712d63231..8c345e611b 100644 --- a/packages/go/dawgs/drivers/pg/manager.go +++ b/packages/go/dawgs/drivers/pg/manager.go @@ -173,6 +173,20 @@ func (s *SchemaManager) AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([] return kindIDs, nil } +func (s *SchemaManager) SetDefaultGraph(tx graph.Transaction, schema graph.Graph) error { + // Validate the schema if the graph already exists in the database + if definition, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { + return err + } else { + s.graphs[schema.Name] = definition + + s.defaultGraph = definition + s.hasDefaultGraph = true + } + + return nil +} + func (s *SchemaManager) AssertDefaultGraph(tx graph.Transaction, schema graph.Graph) error { if graphInstance, err := s.AssertGraph(tx, schema); err != nil { return err diff --git a/packages/go/dawgs/graph/graph.go b/packages/go/dawgs/graph/graph.go index 1a6446bcbf..a655447923 100644 --- a/packages/go/dawgs/graph/graph.go +++ b/packages/go/dawgs/graph/graph.go @@ -414,6 +414,9 @@ type Database interface { // AssertSchema will apply the given schema to the underlying database. AssertSchema(ctx context.Context, dbSchema Schema) error + // SetDefaultGraph sets the default graph namespace for the connection. + SetDefaultGraph(ctx context.Context, graphSchema Graph) error + // Run allows a user to pass statements directly to the database. Since results may rely on a transactional context // only an error is returned from this function Run(ctx context.Context, query string, parameters map[string]any) error diff --git a/packages/go/dawgs/graph/mocks/graph.go b/packages/go/dawgs/graph/mocks/graph.go index 274802fb5f..4a7681a817 100644 --- a/packages/go/dawgs/graph/mocks/graph.go +++ b/packages/go/dawgs/graph/mocks/graph.go @@ -767,6 +767,20 @@ func (mr *MockDatabaseMockRecorder) SetBatchWriteSize(interval interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBatchWriteSize", reflect.TypeOf((*MockDatabase)(nil).SetBatchWriteSize), interval) } +// SetDefaultGraph mocks base method. +func (m *MockDatabase) SetDefaultGraph(ctx context.Context, graphSchema graph.Graph) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDefaultGraph", ctx, graphSchema) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDefaultGraph indicates an expected call of SetDefaultGraph. +func (mr *MockDatabaseMockRecorder) SetDefaultGraph(ctx, graphSchema interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDefaultGraph", reflect.TypeOf((*MockDatabase)(nil).SetDefaultGraph), ctx, graphSchema) +} + // SetWriteFlushSize mocks base method. func (m *MockDatabase) SetWriteFlushSize(interval int) { m.ctrl.T.Helper() diff --git a/packages/go/dawgs/graph/switch.go b/packages/go/dawgs/graph/switch.go index 22fdc11194..0ef88dcd00 100644 --- a/packages/go/dawgs/graph/switch.go +++ b/packages/go/dawgs/graph/switch.go @@ -46,6 +46,13 @@ func NewDatabaseSwitch(ctx context.Context, initialDB Database) *DatabaseSwitch } } +func (s *DatabaseSwitch) SetDefaultGraph(ctx context.Context, graphSchema Graph) error { + s.currentDBLock.RLock() + defer s.currentDBLock.RUnlock() + + return s.currentDB.SetDefaultGraph(ctx, graphSchema) +} + func (s *DatabaseSwitch) Switch(db Database) { s.inSwitch = true