From 5d6f959e6c93153fef11da89246a4dfc5fbf77bb Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 18 Dec 2024 14:55:28 -0800 Subject: [PATCH] CySQL Support and fixes (#1020) * feat: BED-5159 - support coalesce function * fix: BED-5173 - use correct pgsql regex operator * chore: BED-5168 - enable schema manager to act independently to allow inline kind lookups and asserts in select parts of dawgs * chore: BED-5168 - update schema manager calls to support looking up the latest information in the database if there is a kind lookup failure * chore: BED-5168 - prepare for review --- cmd/api/src/api/tools/pg.go | 6 +- cmd/api/src/api/v2/integration/ingest.go | 6 +- .../src/api/v2/integration/reconciliation.go | 3 +- cmd/api/src/test/integration/context.go | 2 +- cmd/api/src/test/integration/dawgs.go | 5 +- .../harnesses/enrollonbehalfof-1.svg | 4 +- .../harnesses/enrollonbehalfof-2.svg | 4 +- .../harnesses/enrollonbehalfof-3.svg | 18 ++ packages/go/cypher/models/cypher/functions.go | 1 + .../go/cypher/models/pgsql/format/format.go | 2 +- packages/go/cypher/models/pgsql/model.go | 4 +- packages/go/cypher/models/pgsql/operators.go | 18 +- packages/go/cypher/models/pgsql/pgtypes.go | 10 + .../go/cypher/models/pgsql/test/kindmapper.go | 55 ---- .../go/cypher/models/pgsql/test/testcase.go | 3 +- .../pgsql/test/translation_cases/delete.sql | 4 +- .../test/translation_cases/harness.cypher | 8 - .../pgsql/test/translation_cases/harness.sql | 43 --- .../pgsql/test/translation_cases/nodes.sql | 40 ++- .../translation_cases/pattern_binding.sql | 6 +- .../translation_cases/pattern_expansion.sql | 30 +-- .../test/translation_cases/shortest_paths.sql | 2 +- .../translation_cases/stepwise_traversal.sql | 18 +- .../pgsql/test/translation_cases/update.sql | 2 +- .../models/pgsql/translate/expression.go | 50 +++- .../cypher/models/pgsql/translate/format.go | 5 +- .../go/cypher/models/pgsql/translate/node.go | 5 +- .../models/pgsql/translate/relationship.go | 6 +- .../models/pgsql/translate/translation.go | 63 ++++- .../models/pgsql/translate/translator.go | 40 ++- .../models/pgsql/translate/translator_test.go | 13 +- .../cypher/models/pgsql/translate/update.go | 8 +- .../pgsql/visualization/visualizer_test.go | 7 +- packages/go/dawgs/drivers/pg/batch.go | 40 +-- packages/go/dawgs/drivers/pg/driver.go | 26 +- packages/go/dawgs/drivers/pg/manager.go | 250 ++++++++++++------ packages/go/dawgs/drivers/pg/mapper.go | 13 +- packages/go/dawgs/drivers/pg/node_test.go | 65 ++--- packages/go/dawgs/drivers/pg/pg.go | 10 +- .../go/dawgs/drivers/pg/pgutil/kindmapper.go | 112 ++++++++ packages/go/dawgs/drivers/pg/query.go | 2 +- packages/go/dawgs/drivers/pg/result.go | 19 +- packages/go/dawgs/drivers/pg/transaction.go | 7 +- packages/go/dawgs/drivers/pg/types.go | 19 +- 44 files changed, 648 insertions(+), 406 deletions(-) create mode 100644 cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg delete mode 100644 packages/go/cypher/models/pgsql/test/kindmapper.go delete mode 100644 packages/go/cypher/models/pgsql/test/translation_cases/harness.cypher delete mode 100644 packages/go/cypher/models/pgsql/test/translation_cases/harness.sql create mode 100644 packages/go/dawgs/drivers/pg/pgutil/kindmapper.go diff --git a/cmd/api/src/api/tools/pg.go b/cmd/api/src/api/tools/pg.go index 718c42e198..fbd597d9af 100644 --- a/cmd/api/src/api/tools/pg.go +++ b/cmd/api/src/api/tools/pg.go @@ -82,10 +82,8 @@ func migrateTypes(ctx context.Context, neoDB, pgDB graph.Database) error { return err } - return pgDB.WriteTransaction(ctx, func(tx graph.Transaction) error { - _, err := pgDB.(*pg.Driver).KindMapper().AssertKinds(tx, append(neoNodeKinds, neoEdgeKinds...)) - return err - }) + _, err := pgDB.(*pg.Driver).KindMapper().AssertKinds(ctx, append(neoNodeKinds, neoEdgeKinds...)) + return err } func convertNeo4jProperties(properties *graph.Properties) error { diff --git a/cmd/api/src/api/v2/integration/ingest.go b/cmd/api/src/api/v2/integration/ingest.go index 02bf2c729a..ec03c20e0a 100644 --- a/cmd/api/src/api/v2/integration/ingest.go +++ b/cmd/api/src/api/v2/integration/ingest.go @@ -21,6 +21,8 @@ import ( "strings" "time" + "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/src/model" "github.com/specterops/bloodhound/src/model/appcfg" @@ -198,7 +200,7 @@ func (s *Context) WaitForDatapipeAnalysis(timeout time.Duration, originalWrapper type IngestAssertion func(testCtrl test.Controller, tx graph.Transaction) func (s *Context) AssertIngest(assertion IngestAssertion) { - graphDB := integration.OpenGraphDB(s.TestCtrl) + graphDB := integration.OpenGraphDB(s.TestCtrl, graphschema.DefaultGraphSchema()) defer graphDB.Close(s.ctx) require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.ctx, func(tx graph.Transaction) error { @@ -208,7 +210,7 @@ func (s *Context) AssertIngest(assertion IngestAssertion) { } func (s *Context) AssertIngestProperties(assertion IngestAssertion) { - graphDB := integration.OpenGraphDB(s.TestCtrl) + graphDB := integration.OpenGraphDB(s.TestCtrl, graphschema.DefaultGraphSchema()) defer graphDB.Close(s.ctx) require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.ctx, func(tx graph.Transaction) error { diff --git a/cmd/api/src/api/v2/integration/reconciliation.go b/cmd/api/src/api/v2/integration/reconciliation.go index f7e767a579..c442a60cc9 100644 --- a/cmd/api/src/api/v2/integration/reconciliation.go +++ b/cmd/api/src/api/v2/integration/reconciliation.go @@ -18,6 +18,7 @@ package integration import ( "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/graphschema" "github.com/specterops/bloodhound/src/test" "github.com/specterops/bloodhound/src/test/integration" "github.com/stretchr/testify/require" @@ -26,7 +27,7 @@ import ( type ReconciliationAssertion func(testCtrl test.Controller, tx graph.Transaction) func (s *Context) AssertReconciliation(assertion ReconciliationAssertion) { - graphDB := integration.OpenGraphDB(s.TestCtrl) + graphDB := integration.OpenGraphDB(s.TestCtrl, graphschema.DefaultGraphSchema()) defer graphDB.Close(s.ctx) require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.ctx, func(tx graph.Transaction) error { diff --git a/cmd/api/src/test/integration/context.go b/cmd/api/src/test/integration/context.go index ca425003e7..ef2e7eae05 100644 --- a/cmd/api/src/test/integration/context.go +++ b/cmd/api/src/test/integration/context.go @@ -67,7 +67,7 @@ func (s *GraphContext) End(t test.Context) { func NewGraphContext(ctx test.Context, schema graph.Schema) *GraphContext { graphContext := &GraphContext{ schema: schema, - Database: OpenGraphDB(ctx), + Database: OpenGraphDB(ctx, schema), } // Initialize the graph context diff --git a/cmd/api/src/test/integration/dawgs.go b/cmd/api/src/test/integration/dawgs.go index 4c87d12fcf..0b2f8b61b2 100644 --- a/cmd/api/src/test/integration/dawgs.go +++ b/cmd/api/src/test/integration/dawgs.go @@ -23,7 +23,6 @@ import ( "github.com/specterops/bloodhound/dawgs/drivers/neo4j" "github.com/specterops/bloodhound/dawgs/drivers/pg" "github.com/specterops/bloodhound/dawgs/graph" - schema "github.com/specterops/bloodhound/graphschema" "github.com/specterops/bloodhound/src/config" "github.com/specterops/bloodhound/src/test" "github.com/specterops/bloodhound/src/test/integration/utils" @@ -39,7 +38,7 @@ func LoadConfiguration(testCtrl test.Controller) config.Configuration { return cfg } -func OpenGraphDB(testCtrl test.Controller) graph.Database { +func OpenGraphDB(testCtrl test.Controller, schema graph.Schema) graph.Database { var ( cfg = LoadConfiguration(testCtrl) graphDatabase graph.Database @@ -62,7 +61,7 @@ func OpenGraphDB(testCtrl test.Controller) graph.Database { } test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err) - test.RequireNilErr(testCtrl, graphDatabase.AssertSchema(context.Background(), schema.DefaultGraphSchema())) + test.RequireNilErr(testCtrl, graphDatabase.AssertSchema(context.Background(), schema)) return graphDatabase } diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg index a86d80d2f0..4f22563a8b 100644 --- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg @@ -1,5 +1,5 @@ -NTAuthStoreForTrustedForNTAuthPublishedToPublishedToPublishedToRootCAForEnterpriseCAForEnrollOnBehalfOfEnrollOnBehalfOfDomain1CertTemplate1-1schemaversion:2ekus:["2.5.29.37.0"]NTAuthStore1EnterpriseCA1CertTemplate1-2schemaversion:1ekus:["2.5.29.37.0"]CertTemplate1-3schemaversion:2ekus:["2.5.29.37.0"]RootCA1 +NTAuthStoreForTrustedForNTAuthPublishedToPublishedToPublishedToRootCAForEnterpriseCAForEnrollOnBehalfOfEnrollOnBehalfOfEnrollOnBehalfOfDomain1CertTemplate1-1schemaversion:2effectiveekus:["2.5.29.37.0"]NTAuthStore1EnterpriseCA1CertTemplate1-2schemaversion:1effectiveekus:["2.5.29.37.0"]CertTemplate1-3schemaversion:2effectiveekus:["2.5.29.37.0"]RootCA1 diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg index dd203c5b22..96181905f2 100644 --- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg @@ -1,5 +1,5 @@ -NTAuthStoreForTrustedForNTAuthRootCAForEnterpriseCAForPublishedToPublishedToPublishedToPublishedToEnrollOnBehalfOfPublishedToEnrollOnBehalfOfDomain2NTAuthStore2EnterpriseCA2RootCA2CertTemplate2-1ekus:["1.3.6.1.4.1.311.20.2.1"]CertTemplate2-2ekus:["1.3.6.1.4.1.311.20.2.1", "2.5.29.37.0"]CertTemplate2-3ekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:["1.3.6.1.4.1.311.20.2.1"]CertTemplate2-4ekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:[]CertTemplate2-5ekus:[]schemaversion:1subjectaltrequiresupn:true +NTAuthStoreForTrustedForNTAuthRootCAForEnterpriseCAForPublishedToPublishedToPublishedToPublishedToEnrollOnBehalfOfDomain2NTAuthStore2EnterpriseCA2RootCA2CertTemplate2-1effectiveekus:["1.3.6.1.4.1.311.20.2.1"]schemaversion:2CertTemplate2-2effectiveekus:["1.3.6.1.4.1.311.20.2.1", "2.5.29.37.0"]schemaversion:2CertTemplate2-3effectiveekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:["1.3.6.1.4.1.311.20.2.1"]CertTemplate2-4effectiveekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:[] diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg new file mode 100644 index 0000000000..42dced49a9 --- /dev/null +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg @@ -0,0 +1,18 @@ + +NTAuthStoreForTrustedForNTAuthPublishedToPublishedToRootCAForEnterpriseCAForEnrollOnBehalfOfPublishedToEnterpriseCAForEnrollOnBehalfOfDomain1CertTemplate1-1schemaversion:2effectiveekus:["2.5.29.37.0"]NTAuthStore1EnterpriseCA1CertTemplate1-2schemaversion:1effectiveekus:["2.5.29.37.0"]CertTemplate1-3schemaversion:2effectiveekus:["2.5.29.37.0"]RootCA1EnterpriseCA2 diff --git a/packages/go/cypher/models/cypher/functions.go b/packages/go/cypher/models/cypher/functions.go index c6dce37f2b..59cd959c9d 100644 --- a/packages/go/cypher/models/cypher/functions.go +++ b/packages/go/cypher/models/cypher/functions.go @@ -33,6 +33,7 @@ const ( ToStringFunction = "tostring" ToIntegerFunction = "toint" ListSizeFunction = "size" + CoalesceFunction = "coalesce" // ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/) ITTCYear = "year" diff --git a/packages/go/cypher/models/pgsql/format/format.go b/packages/go/cypher/models/pgsql/format/format.go index 4280aef044..0b09b19c46 100644 --- a/packages/go/cypher/models/pgsql/format/format.go +++ b/packages/go/cypher/models/pgsql/format/format.go @@ -206,7 +206,7 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error { exprStack = append(exprStack, *typedNextExpr) case pgsql.FunctionCall: - if typedNextExpr.CastType != pgsql.UnsetDataType { + if typedNextExpr.CastType.IsKnown() { exprStack = append(exprStack, typedNextExpr.CastType, pgsql.FormattingLiteral("::")) } diff --git a/packages/go/cypher/models/pgsql/model.go b/packages/go/cypher/models/pgsql/model.go index c890a855e0..d5f3d06a40 100644 --- a/packages/go/cypher/models/pgsql/model.go +++ b/packages/go/cypher/models/pgsql/model.go @@ -17,6 +17,7 @@ package pgsql import ( + "context" "strings" "github.com/specterops/bloodhound/dawgs/graph" @@ -27,7 +28,8 @@ import ( // KindMapper is an interface that represents a service that can map a given slice of graph.Kind to a slice of // int16 numeric identifiers. type KindMapper interface { - MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) + MapKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) + AssertKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) } // FormattingLiteral is a syntax node that is used as a transparent formatting syntax node. The formatter will diff --git a/packages/go/cypher/models/pgsql/operators.go b/packages/go/cypher/models/pgsql/operators.go index 2f0010285d..6034ad4b37 100644 --- a/packages/go/cypher/models/pgsql/operators.go +++ b/packages/go/cypher/models/pgsql/operators.go @@ -99,13 +99,15 @@ const ( OperatorIs Operator = "is" OperatorIsNot Operator = "is not" OperatorSimilarTo Operator = "similar to" - OperatorRegexMatch Operator = "=~" - OperatorStartsWith Operator = "starts with" - OperatorContains Operator = "contains" - OperatorEndsWith Operator = "ends with" - OperatorPropertyLookup Operator = "property_lookup" + OperatorRegexMatch Operator = "~" + OperatorAssignment Operator = "=" + OperatorAdditionAssignment Operator = "+=" - OperatorAssignment = OperatorEquals - OperatorLabelAssignment Operator = "label_assignment" - OperatorAdditionAssignment Operator = "+=" + OperatorCypherRegexMatch Operator = "=~" + OperatorCypherStartsWith Operator = "starts with" + OperatorCypherContains Operator = "contains" + OperatorCypherEndsWith Operator = "ends with" + + OperatorPropertyLookup Operator = "property_lookup" + OperatorKindAssignment Operator = "kind_assignment" ) diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go index 8839e37971..c50d8656b9 100644 --- a/packages/go/cypher/models/pgsql/pgtypes.go +++ b/packages/go/cypher/models/pgsql/pgtypes.go @@ -111,6 +111,16 @@ const ( ExpansionTerminalNode DataType = "expansion_terminal_node" ) +func (s DataType) IsKnown() bool { + switch s { + case UnsetDataType, UnknownDataType: + return false + + default: + return true + } +} + // TODO: operator, while unused, is part of a refactor for this function to make it operator aware func (s DataType) Compatible(other DataType, operator Operator) (DataType, bool) { if s == other { diff --git a/packages/go/cypher/models/pgsql/test/kindmapper.go b/packages/go/cypher/models/pgsql/test/kindmapper.go deleted file mode 100644 index bcbe000329..0000000000 --- a/packages/go/cypher/models/pgsql/test/kindmapper.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2024 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package test - -import ( - "github.com/specterops/bloodhound/dawgs/graph" -) - -type InMemoryKindMapper struct { - KindToID map[graph.Kind]int16 - IDToKind map[int16]graph.Kind -} - -func NewInMemoryKindMapper() *InMemoryKindMapper { - return &InMemoryKindMapper{ - KindToID: map[graph.Kind]int16{}, - IDToKind: map[int16]graph.Kind{}, - } -} - -func (s *InMemoryKindMapper) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { - var ( - ids = make([]int16, 0, len(kinds)) - missing = make(graph.Kinds, 0, len(kinds)) - ) - - for _, kind := range kinds { - if id, found := s.KindToID[kind]; !found { - missing = append(missing, kind) - } else { - ids = append(ids, id) - } - } - - return ids, missing -} - -func (s *InMemoryKindMapper) Put(kind graph.Kind, id int16) { - s.KindToID[kind] = id - s.IDToKind[id] = kind -} diff --git a/packages/go/cypher/models/pgsql/test/testcase.go b/packages/go/cypher/models/pgsql/test/testcase.go index afb4ab6095..afdef35a65 100644 --- a/packages/go/cypher/models/pgsql/test/testcase.go +++ b/packages/go/cypher/models/pgsql/test/testcase.go @@ -17,6 +17,7 @@ package test import ( + "context" "embed" "encoding/json" "fmt" @@ -82,7 +83,7 @@ func (s *TranslationTestCase) Assert(t *testing.T, expectedSQL string, kindMappe } } - if translation, err := translate.Translate(regularQuery, kindMapper); err != nil { + if translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil); err != nil { t.Fatalf("Failed to translate cypher query: %s - %v", s.Cypher, err) } else if formattedQuery, err := translate.Translated(translation); err != nil { t.Fatalf("Failed to format SQL translatedQuery: %v", err) diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/delete.sql b/packages/go/cypher/models/pgsql/test/translation_cases/delete.sql index b99b30f971..0ed62e66c9 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/delete.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/delete.sql @@ -29,7 +29,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])), + where e0.kind_id = any (array [3]::int2[])), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id returning (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e0, s0.n0 as n0, s0.n1 as n1) select 1; @@ -49,7 +49,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from s0, edge e1 join node n2 on n2.id = e1.end_id - where e1.kind_id = any (array [11]::int2[]) + where e1.kind_id = any (array [3]::int2[]) and (s0.n1).id = e1.start_id), s2 as (delete from edge e2 using s1 where (s1.e1).id = e2.id returning s1.e0 as e0, (e2.id, e2.start_id, e2.end_id, e2.kind_id, e2.properties)::edgecomposite as e1, s1.n0 as n0, s1.n1 as n1, s1.n2 as n2) diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/harness.cypher b/packages/go/cypher/models/pgsql/test/translation_cases/harness.cypher deleted file mode 100644 index 6719889729..0000000000 --- a/packages/go/cypher/models/pgsql/test/translation_cases/harness.cypher +++ /dev/null @@ -1,8 +0,0 @@ -match (n) detach delete n; - -create (n1:NodeKind1 {name: 'n1'}) -create (n2:NodeKind1 {name: 'n2'}) set n2:NodeKind2 -create (n3:NodeKind1 {name: 'n3'}) set n3:NodeKind2 -create (n4:NodeKind2 {name: 'n4'}) -create (n5:NodeKind2 {name: 'n5'}) -create (n1)-[:EdgeKind1 {name: 'e1', prop: 'a'}]->(n2)-[:EdgeKind1 {name: 'e2', prop: 'a'}]->(n3)-[:EdgeKind1 {name: 'e3'}]->(n4); diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/harness.sql b/packages/go/cypher/models/pgsql/test/translation_cases/harness.sql deleted file mode 100644 index ae2c24d9ce..0000000000 --- a/packages/go/cypher/models/pgsql/test/translation_cases/harness.sql +++ /dev/null @@ -1,43 +0,0 @@ --- Copyright 2024 Specter Ops, Inc. --- --- Licensed under the Apache License, Version 2.0 --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. --- --- SPDX-License-Identifier: Apache-2.0 - -truncate table kind; -truncate table edge; -truncate table node; - -insert into kind (id, name) -values (1, 'NodeKind1'), - (2, 'NodeKind2'), - (11, 'EdgeKind1'), - (12, 'EdgeKind2'), - (13, 'EdgeKind3'); - -insert into node (id, graph_id, kind_ids, properties) -values (1, 1, array [1], '{"name": "n1"}'), - (2, 1, array [1, 2], '{"name": "n2"}'), - (3, 1, array [1, 2], '{"name": "n3"}'), - (4, 1, array [2], '{"name": "n4"}'), - (5, 1, array [2], '{"name": "n5"}'), - (6, 1, array [1, 2], '{"name": "n6"}'); - -insert into edge (graph_id, start_id, end_id, kind_id, properties) -values (1, 1, 2, 11, '{"name": "e1", "prop": "a"}'), - (1, 2, 3, 12, '{"name": "e2", "prop": "a"}'), - (1, 3, 4, 12, '{"name": "e3", "prop": "a"}'), - (1, 4, 5, 12, '{"name": "e4", "prop": "a"}'), - (1, 2, 4, 11, '{"name": "e5", "prop": "a"}'), - (1, 2, 4, 13, '{"name": "e6", "prop": "a"}'), - (1, 3, 4, 11, '{"name": "e7"}'); diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql index 69b09a8b98..02181bfc8a 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql @@ -220,10 +220,10 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 select s0.n0 as s from s0; --- case: match (s) where s.created_at = localtime('12:12:12') return s +-- case: match (s) where s.created_at = localtime('4:4:4') return s with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 - where (n0.properties ->> 'created_at')::time without time zone = ('12:12:12')::time without time zone) + where (n0.properties ->> 'created_at')::time without time zone = ('4:4:4')::time without time zone) select s0.n0 as s from s0; @@ -234,10 +234,10 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 select s0.n0 as s from s0; --- case: match (s) where s.created_at = date('2023-12-12') return s +-- case: match (s) where s.created_at = date('2023-4-4') return s with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 - where (n0.properties ->> 'created_at')::date = ('2023-12-12')::date) + where (n0.properties ->> 'created_at')::date = ('2023-4-4')::date) select s0.n0 as s from s0; @@ -561,3 +561,35 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 or @pi1::text = any (jsonb_to_text_array(n0.properties -> 'array')::text[])) select s0.n0 as n from s0; + +-- case: match (n:NodeKind1) where coalesce(n.system_tags, '') contains 'admin_tier_0' return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and coalesce(n0.properties ->> 'system_tags', '')::text like '%admin_tier_0%') +select s0.n0 as n +from s0; + +-- case: match (n:NodeKind1) where coalesce(n.a, n.b, 1) = 1 return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and coalesce((n0.properties -> 'a')::int8, (n0.properties -> 'b')::int8, 1)::int8 = 1) +select s0.n0 as n +from s0; + +-- case: match (n:NodeKind1) where coalesce(n.a, n.b) = 1 return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and coalesce(n0.properties -> 'a', n0.properties -> 'b')::int8 = 1) +select s0.n0 as n +from s0; + +-- case: match (n:NodeKind1) where 1 = coalesce(n.a, n.b) return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and 1 = coalesce(n0.properties -> 'a', n0.properties -> 'b')::int8) +select s0.n0 as n +from s0; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql index 4cb3465067..adda9c1bea 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql @@ -185,7 +185,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])), + where e0.kind_id = any (array [3]::int2[])), s1 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select e1.start_id, e1.end_id, 1, @@ -194,7 +194,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite array [e1.id] from s0 join edge e1 - on e1.kind_id = any (array [11]::int2[]) and (s0.n1).id = e1.start_id + on e1.kind_id = any (array [3]::int2[]) and (s0.n1).id = e1.start_id join node n2 on n2.id = e1.end_id union select ex0.root_id, @@ -234,7 +234,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite join node n1 on n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and not upper(n1.properties ->> 'operatingsystem')::text like '%SERVER%' and n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])) + where e0.kind_id = any (array [3]::int2[])) select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p from s0 limit 1000; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql index 87da96d391..2c2efc16bb 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql @@ -248,7 +248,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from s0, edge e1 join node n2 on n2.id = e1.end_id - where e1.kind_id = any (array [11, 12]::int2[]) + where e1.kind_id = any (array [3, 4]::int2[]) and (s0.n1).id = e1.start_id), s2 as (with recursive ex1(root_id, next_id, depth, satisfied, is_cycle, path) as (select e2.start_id, e2.end_id, @@ -306,7 +306,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat array [1]::int2[] and n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[]) + where e0.kind_id = any (array [3]::int2[]) union select ex0.root_id, e0.end_id, @@ -322,7 +322,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle - and e0.kind_id = any (array [11]::int2[])) + and e0.kind_id = any (array [3]::int2[])) select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 @@ -396,7 +396,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat 'objectid' like '%1234' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11, 12]::int2[]) + where e0.kind_id = any (array [3, 4]::int2[]) union select ex0.root_id, e0.end_id, @@ -412,7 +412,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle - and e0.kind_id = any (array [11, 12]::int2[])) + and e0.kind_id = any (array [3, 4]::int2[])) select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 @@ -440,7 +440,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat array [1]::int2[] and n0.id = e0.end_id join node n1 on n1.id = e0.start_id - where e0.kind_id = any (array [11, 12]::int2[]) + where e0.kind_id = any (array [3, 4]::int2[]) union select ex0.root_id, e0.end_id, @@ -453,7 +453,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat join node n1 on n1.id = e0.start_id where ex0.depth < 5 and not ex0.is_cycle - and e0.kind_id = any (array [11, 12]::int2[])) + and e0.kind_id = any (array [3, 4]::int2[])) select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 @@ -482,7 +482,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat array [1]::int2[] and n0.id = e0.end_id join node n1 on n1.id = e0.start_id - where e0.kind_id = any (array [11, 12]::int2[]) + where e0.kind_id = any (array [3, 4]::int2[]) union select ex0.root_id, e0.end_id, @@ -495,7 +495,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat join node n1 on n1.id = e0.start_id where ex0.depth < 5 and not ex0.is_cycle - and e0.kind_id = any (array [11, 12]::int2[])) + and e0.kind_id = any (array [3, 4]::int2[])) select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 @@ -516,7 +516,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from s0 join edge e1 on e1.kind_id = any - (array [11, 12]::int2[]) and + (array [3, 4]::int2[]) and (s0.n1).id = e1.end_id join node n2 on n2.id = e1.start_id union @@ -555,8 +555,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat 1, n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and - n1.properties ->> - 'name' similar to + n1.properties ->> 'name' ~ '(?i)Global Administrator.*|User Administrator.*|Cloud Application Administrator.*|Authentication Policy Administrator.*|Exchange Administrator.*|Helpdesk Administrator.*|Privileged Authentication Administrator.*', e0.start_id = e0.end_id, array [e0.id] @@ -566,15 +565,14 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat array [1]::int2[] and n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11, 12]::int2[]) + where e0.kind_id = any (array [3, 4]::int2[]) union select ex0.root_id, e0.end_id, ex0.depth + 1, n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and - n1.properties ->> - 'name' similar to + n1.properties ->> 'name' ~ '(?i)Global Administrator.*|User Administrator.*|Cloud Application Administrator.*|Authentication Policy Administrator.*|Exchange Administrator.*|Helpdesk Administrator.*|Privileged Authentication Administrator.*', e0.id = any (ex0.path), ex0.path || e0.id @@ -583,7 +581,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle - and e0.kind_id = any (array [11, 12]::int2[])) + and e0.kind_id = any (array [3, 4]::int2[])) select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) from edge e0 diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql index b7035887f1..d8fc95bc8f 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql @@ -53,7 +53,7 @@ from s0; -- case: match p=shortestPath((n:NodeKind1)-[:EdgeKind1*1..]->(m)) where 'admin_tier_0' in split(m.system_tags, ' ') and n.objectid ends with '-513' and n<>m return p limit 1000 -- cypher_params: {} --- pgsql_params: {"pi0":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select e0.start_id, e0.end_id, 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.start_id = e0.end_id, array [e0.id] from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.properties ->> 'objectid' like '%-513' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [11]::int2[]);", "pi1":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select ex0.root_id, e0.end_id, ex0.depth + 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.id = any (ex0.path), ex0.path || e0.id from pathspace ex0 join edge e0 on e0.start_id = ex0.next_id join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle and e0.kind_id = any (array [11]::int2[]);"} +-- pgsql_params: {"pi0":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select e0.start_id, e0.end_id, 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.start_id = e0.end_id, array [e0.id] from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.properties ->> 'objectid' like '%-513' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3]::int2[]);", "pi1":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select ex0.root_id, e0.end_id, ex0.depth + 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.id = any (ex0.path), ex0.path || e0.id from pathspace ex0 join edge e0 on e0.start_id = ex0.next_id join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle and e0.kind_id = any (array [3]::int2[]);"} with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select * from asp_harness(@pi0::text, @pi1::text, 5)) select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql index 21cb8fbc11..00089b8f4c 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql @@ -66,7 +66,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])) + where e0.kind_id = any (array [3]::int2[])) select count(s0.e0)::int8 as the_count from s0; @@ -81,7 +81,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite join node n1 on n1.id = @pi0::float8 and n1.id = e0.end_id where not (n0.properties ->> 'objectid' like '%' || @pi2::text or n1.properties ->> 'objectid' like '%' || @pi3::text) - and (e0.kind_id = any (array [11]::int2[]) or e0.kind_id = any (array [12]::int2[]))) + and (e0.kind_id = any (array [3]::int2[]) or e0.kind_id = any (array [4]::int2[]))) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0; @@ -209,7 +209,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.end_id join node n1 on n1.id = e0.start_id - where e0.kind_id = any (array [11, 12]::int2[])) + where e0.kind_id = any (array [3, 4]::int2[])) select (s0.n0).properties -> 'name', (s0.n1).properties -> 'name' from s0; @@ -220,7 +220,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11, 12]::int2[])), + where e0.kind_id = any (array [3, 4]::int2[])), s1 as (select s0.e0 as e0, s0.n0 as n0, s0.n1 as n1, @@ -229,7 +229,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from s0, edge e1 join node n2 on n2.id = e1.end_id - where e1.kind_id = any (array [11]::int2[]) + where e1.kind_id = any (array [3]::int2[]) and (s0.n1).id = e1.start_id) select (s1.n0).properties -> 'name' as s_name, (s1.n1).properties -> 'name' as e_name from s1; @@ -241,7 +241,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.id = e0.start_id join node n1 on n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and n1.id = e0.end_id - where e0.kind_id = any (array [11, 12]::int2[])) + where e0.kind_id = any (array [3, 4]::int2[])) select (s0.n0).properties -> 'name', (s0.n1).properties -> 'name' from s0; @@ -252,7 +252,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])) + where e0.kind_id = any (array [3]::int2[])) select s0.n0 as s from s0 where (with s1 as (select s0.e0 as e0, @@ -275,7 +275,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite join node n0 on not (coalesce(n0.properties ->> 'system_tags', '')::text like '%admin_tier_0%') and n0.id = e0.start_id join node n1 on n1.id = 1 and n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])) + where e0.kind_id = any (array [3]::int2[])) select (s0.n0).id, (s0.n0).kind_ids, (s0.e0).id, (s0.e0).kind_id from s0; @@ -288,7 +288,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite lower(n0.properties ->> 'name')::text like 'test' and n0.id = e0.start_id join node n1 on n1.id = any (array [1, 2]::int8[]) and n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])) + where e0.kind_id = any (array [3]::int2[])) select s0.e0 as r from s0 limit 1; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/update.sql b/packages/go/cypher/models/pgsql/test/translation_cases/update.sql index a40e7f598b..b4488a0f1e 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/update.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/update.sql @@ -146,7 +146,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where e0.kind_id = any (array [11]::int2[])), + where e0.kind_id = any (array [3]::int2[])), s1 as (update edge e1 set properties = e1.properties || jsonb_build_object('visited', true)::jsonb from s0 where (s0.e0).id = e1.id returning (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e0, s0.n0 as n0, s0.n1 as n1) diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go index 1fdd24a364..d3071a474b 100644 --- a/packages/go/cypher/models/pgsql/translate/expression.go +++ b/packages/go/cypher/models/pgsql/translate/expression.go @@ -170,7 +170,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy // If neither side has specific type information then check the operator to see if it implies some type // hinting before resorting to inference switch expression.Operator { - case pgsql.OperatorStartsWith, pgsql.OperatorContains, pgsql.OperatorEndsWith: + case pgsql.OperatorCypherStartsWith, pgsql.OperatorCypherContains, pgsql.OperatorCypherEndsWith: // String operations imply the operands must be text return pgsql.Text, nil @@ -320,7 +320,7 @@ func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) } - case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch: + case pgsql.OperatorCypherStartsWith, pgsql.OperatorCypherEndsWith, pgsql.OperatorCypherContains, pgsql.OperatorRegexMatch: expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, pgsql.Text) default: @@ -341,7 +341,7 @@ func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, arrayType) } - case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch: + case pgsql.OperatorCypherStartsWith, pgsql.OperatorCypherEndsWith, pgsql.OperatorCypherContains, pgsql.OperatorRegexMatch: expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, pgsql.Text) default: @@ -353,6 +353,34 @@ func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { return nil } +func applyTypeFunctionCallTypeHints(expression *pgsql.BinaryExpression) error { + switch typedLOperand := expression.LOperand.(type) { + case pgsql.FunctionCall: + if !typedLOperand.CastType.IsKnown() { + if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { + return err + } else { + typedLOperand.CastType = rOperandTypeHint + expression.LOperand = typedLOperand + } + } + } + + switch typedROperand := expression.ROperand.(type) { + case pgsql.FunctionCall: + if !typedROperand.CastType.IsKnown() { + if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil { + return err + } else { + typedROperand.CastType = lOperandTypeHint + expression.ROperand = typedROperand + } + } + } + + return nil +} + func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error { switch expression.Operator { case pgsql.OperatorPropertyLookup: @@ -361,7 +389,11 @@ func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error { return nil } - return rewritePropertyLookupOperands(expression) + if err := rewritePropertyLookupOperands(expression); err != nil { + return err + } + + return applyTypeFunctionCallTypeHints(expression) } type Builder struct { @@ -651,7 +683,7 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato } switch operator { - case pgsql.OperatorContains: + case pgsql.OperatorCypherContains: newExpression.Operator = pgsql.OperatorLike switch typedLOperand := newExpression.LOperand.(type) { @@ -726,11 +758,11 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato s.Push(newExpression) - case pgsql.OperatorRegexMatch: - newExpression.Operator = pgsql.OperatorSimilarTo + case pgsql.OperatorCypherRegexMatch: + newExpression.Operator = pgsql.OperatorRegexMatch s.Push(newExpression) - case pgsql.OperatorStartsWith: + case pgsql.OperatorCypherStartsWith: newExpression.Operator = pgsql.OperatorLike switch typedLOperand := newExpression.LOperand.(type) { @@ -793,7 +825,7 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato s.Push(newExpression) - case pgsql.OperatorEndsWith: + case pgsql.OperatorCypherEndsWith: newExpression.Operator = pgsql.OperatorLike switch typedLOperand := newExpression.LOperand.(type) { diff --git a/packages/go/cypher/models/pgsql/translate/format.go b/packages/go/cypher/models/pgsql/translate/format.go index 148f54f21c..4bb46021cb 100644 --- a/packages/go/cypher/models/pgsql/translate/format.go +++ b/packages/go/cypher/models/pgsql/translate/format.go @@ -18,6 +18,7 @@ package translate import ( "bytes" + "context" "github.com/specterops/bloodhound/cypher/models/cypher" cypherFormat "github.com/specterops/bloodhound/cypher/models/cypher/format" @@ -29,7 +30,7 @@ func Translated(translation Result) (string, error) { return format.Statement(translation.Statement, format.NewOutputBuilder()) } -func FromCypher(regularQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper, stripLiterals bool) (format.Formatted, error) { +func FromCypher(ctx context.Context, regularQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper, stripLiterals bool) (format.Formatted, error) { var ( output = &bytes.Buffer{} emitter = cypherFormat.NewCypherEmitter(stripLiterals) @@ -43,7 +44,7 @@ func FromCypher(regularQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper, output.WriteString("\n") - if translation, err := Translate(regularQuery, kindMapper); err != nil { + if translation, err := Translate(ctx, regularQuery, kindMapper, nil); err != nil { return format.Formatted{}, err } else if sqlQuery, err := format.Statement(translation.Statement, format.NewOutputBuilder()); err != nil { return format.Formatted{}, err diff --git a/packages/go/cypher/models/pgsql/translate/node.go b/packages/go/cypher/models/pgsql/translate/node.go index 4a6684ab42..19ae8a86e8 100644 --- a/packages/go/cypher/models/pgsql/translate/node.go +++ b/packages/go/cypher/models/pgsql/translate/node.go @@ -18,7 +18,6 @@ package translate import ( "fmt" - "strings" "github.com/specterops/bloodhound/cypher/models" cypher "github.com/specterops/bloodhound/cypher/models/cypher" @@ -48,8 +47,8 @@ func (s *Translator) translateNodePattern(scope *Scope, nodePattern *cypher.Node } if len(nodePattern.Kinds) > 0 { - if kindIDs, missingKinds := s.kindMapper.MapKinds(nodePattern.Kinds); len(missingKinds) > 0 { - s.SetErrorf("unable to map kinds: %s", strings.Join(missingKinds.Strings(), ", ")) + if kindIDs, err := s.kindMapper.MapKinds(s.ctx, nodePattern.Kinds); err != nil { + s.SetError(fmt.Errorf("failed to translate kinds: %w", err)) } else if kindIDsLiteral, err := pgsql.AsLiteral(kindIDs); err != nil { s.SetError(err) } else { diff --git a/packages/go/cypher/models/pgsql/translate/relationship.go b/packages/go/cypher/models/pgsql/translate/relationship.go index f0fc2a441a..a0586207c6 100644 --- a/packages/go/cypher/models/pgsql/translate/relationship.go +++ b/packages/go/cypher/models/pgsql/translate/relationship.go @@ -17,7 +17,7 @@ package translate import ( - "strings" + "fmt" "github.com/specterops/bloodhound/cypher/models" cypher "github.com/specterops/bloodhound/cypher/models/cypher" @@ -50,8 +50,8 @@ func (s *Translator) translateRelationshipPattern(scope *Scope, relationshipPatt // Capture the kind matchers for this relationship pattern if len(relationshipPattern.Kinds) > 0 { - if kindIDs, missingKinds := s.kindMapper.MapKinds(relationshipPattern.Kinds); len(missingKinds) > 0 { - s.SetErrorf("unable to map kinds: %s", strings.Join(missingKinds.Strings(), ", ")) + if kindIDs, err := s.kindMapper.MapKinds(s.ctx, relationshipPattern.Kinds); err != nil { + s.SetError(fmt.Errorf("failed to translate kinds: %w", err)) } else if kindIDsLiteral, err := pgsql.AsLiteral(kindIDs); err != nil { s.SetError(err) } else { diff --git a/packages/go/cypher/models/pgsql/translate/translation.go b/packages/go/cypher/models/pgsql/translate/translation.go index 5bc71308a9..7b109e1209 100644 --- a/packages/go/cypher/models/pgsql/translate/translation.go +++ b/packages/go/cypher/models/pgsql/translate/translation.go @@ -18,7 +18,6 @@ package translate import ( "fmt" - "strings" "github.com/specterops/bloodhound/cypher/models" "github.com/specterops/bloodhound/cypher/models/cypher" @@ -30,7 +29,7 @@ func translateCypherAssignmentOperator(operator cypher.AssignmentOperator) (pgsq case cypher.OperatorAssignment: return pgsql.OperatorAssignment, nil case cypher.OperatorLabelAssignment: - return pgsql.OperatorLabelAssignment, nil + return pgsql.OperatorKindAssignment, nil default: return pgsql.UnsetOperator, fmt.Errorf("unsupported assignment operator %s", operator) } @@ -146,7 +145,7 @@ func (s *Translator) translateSetItem(setItem *cypher.SetItem) error { return s.mutations.AddPropertyAssignment(s.query.Scope, leftPropertyLookup, operator, rightOperand) } - case pgsql.OperatorLabelAssignment: + case pgsql.OperatorKindAssignment: if rightOperand, err := s.treeTranslator.Pop(); err != nil { return err } else if leftOperand, err := s.treeTranslator.Pop(); err != nil { @@ -232,13 +231,67 @@ func (s *Translator) translateDateTimeFunctionCall(cypherFunc *cypher.FunctionIn return nil } +func (s *Translator) translateCoalesceFunction(functionInvocation *cypher.FunctionInvocation) error { + if numArgs := functionInvocation.NumArguments(); numArgs == 0 { + s.SetError(fmt.Errorf("expected at least one argument for cypher function: %s", functionInvocation.Name)) + } else { + var ( + arguments = make([]pgsql.Expression, numArgs) + expectedType = pgsql.UnsetDataType + ) + + // This loop is used to pop off the coalesce function arguments in the intended order (since they're + // pushed onto the translator stack). + for idx := range functionInvocation.Arguments { + if argument, err := s.treeTranslator.Pop(); err != nil { + return err + } else { + arguments[numArgs-idx-1] = argument + } + } + + // Find and validate types of the arguments + for _, argument := range arguments { + if argumentType, err := InferExpressionType(argument); err != nil { + return err + } else if argumentType.IsKnown() { + // If the expected type isn't known yet then assign the known inferred type to it + if !expectedType.IsKnown() { + expectedType = argumentType + } else if expectedType != argumentType { + // All other inferrable argument types must match the first inferred type encountered + return fmt.Errorf("types in coalesce function must match %s but got %s", expectedType, argumentType) + } + } + } + + if expectedType.IsKnown() { + // Rewrite any property lookup operators now that we have some type information + for idx, argument := range arguments { + if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + arguments[idx] = rewritePropertyLookupOperator(propertyLookup, expectedType) + } + } + } + + // Translate the function call to the expected SQL form + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionCoalesce, + Parameters: arguments, + CastType: expectedType, + }) + } + + return nil +} + func (s *Translator) translateKindMatcher(kindMatcher *cypher.KindMatcher) error { if variable, isVariable := kindMatcher.Reference.(*cypher.Variable); !isVariable { return fmt.Errorf("expected variable for kind matcher reference but found type: %T", kindMatcher.Reference) } else if binding, resolved := s.query.Scope.LookupString(variable.Symbol); !resolved { return fmt.Errorf("unable to find identifier %s", variable.Symbol) - } else if kindIDs, missingKinds := s.kindMapper.MapKinds(kindMatcher.Kinds); len(missingKinds) > 0 { - return fmt.Errorf("unable to map kinds: %s", strings.Join(missingKinds.Strings(), ", ")) + } else if kindIDs, err := s.kindMapper.MapKinds(s.ctx, kindMatcher.Kinds); err != nil { + s.SetError(fmt.Errorf("failed to translate kinds: %w", err)) } else if kindIDsLiteral, err := pgsql.AsLiteral(kindIDs); err != nil { return err } else { diff --git a/packages/go/cypher/models/pgsql/translate/translator.go b/packages/go/cypher/models/pgsql/translate/translator.go index 02fa92daca..8c15f52f9c 100644 --- a/packages/go/cypher/models/pgsql/translate/translator.go +++ b/packages/go/cypher/models/pgsql/translate/translator.go @@ -17,6 +17,7 @@ package translate import ( + "context" "fmt" "strings" @@ -68,6 +69,7 @@ func (s State) String() string { type Translator struct { walk.HierarchicalVisitor[cypher.SyntaxNode] + ctx context.Context kindMapper pgsql.KindMapper translation Result state []State @@ -80,12 +82,17 @@ type Translator struct { query *Query } -func NewTranslator(kindMapper pgsql.KindMapper) *Translator { +func NewTranslator(ctx context.Context, kindMapper pgsql.KindMapper, parameters map[string]any) *Translator { + if parameters == nil { + parameters = map[string]any{} + } + return &Translator{ HierarchicalVisitor: walk.NewComposableHierarchicalVisitor[cypher.SyntaxNode](), translation: Result{ - Parameters: map[string]any{}, + Parameters: parameters, }, + ctx: ctx, kindMapper: kindMapper, treeTranslator: NewExpressionTreeTranslator(), properties: map[string]pgsql.Expression{}, @@ -482,7 +489,9 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { switch formattedName { case cypher.IdentityFunction: - if referenceArgument, err := PopFromBuilderAs[pgsql.Identifier](s.treeTranslator); err != nil { + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if referenceArgument, err := PopFromBuilderAs[pgsql.Identifier](s.treeTranslator); err != nil { s.SetError(err) } else { s.treeTranslator.Push(pgsql.CompoundIdentifier{referenceArgument, pgsql.ColumnID}) @@ -509,7 +518,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.EdgeTypeFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -520,7 +529,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.NodeLabelsFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -531,7 +540,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.CountFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -572,7 +581,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.ToLowerFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -590,7 +599,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.ListSizeFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -615,7 +624,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.ToUpperFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -633,7 +642,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.ToStringFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -642,7 +651,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case cypher.ToIntegerFunction: - if typedExpression.NumArguments() > 1 { + if typedExpression.NumArguments() != 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) } else if argument, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -650,6 +659,11 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Int8)) } + case cypher.CoalesceFunction: + if err := s.translateCoalesceFunction(typedExpression); err != nil { + s.SetError(err) + } + default: s.SetErrorf("unknown cypher function: %s", typedExpression.Name) } @@ -848,8 +862,8 @@ type Result struct { Parameters map[string]any } -func Translate(cypherQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper) (Result, error) { - translator := NewTranslator(kindMapper) +func Translate(ctx context.Context, cypherQuery *cypher.RegularQuery, kindMapper pgsql.KindMapper, parameters map[string]any) (Result, error) { + translator := NewTranslator(ctx, kindMapper, parameters) if err := walk.WalkCypher(cypherQuery, translator); err != nil { return Result{}, err diff --git a/packages/go/cypher/models/pgsql/translate/translator_test.go b/packages/go/cypher/models/pgsql/translate/translator_test.go index af5fd15e1a..724211293f 100644 --- a/packages/go/cypher/models/pgsql/translate/translator_test.go +++ b/packages/go/cypher/models/pgsql/translate/translator_test.go @@ -20,6 +20,8 @@ import ( "fmt" "testing" + "github.com/specterops/bloodhound/dawgs/drivers/pg/pgutil" + "github.com/specterops/bloodhound/cypher/models/pgsql" "github.com/specterops/bloodhound/cypher/models/pgsql/test" "github.com/specterops/bloodhound/dawgs/graph" @@ -33,12 +35,13 @@ var ( ) func newKindMapper() pgsql.KindMapper { - mapper := test.NewInMemoryKindMapper() + mapper := pgutil.NewInMemoryKindMapper() - mapper.Put(NodeKind1, 1) - mapper.Put(NodeKind2, 2) - mapper.Put(EdgeKind1, 11) - mapper.Put(EdgeKind2, 12) + // This is here to make SQL output a little more predictable for test cases + mapper.Put(NodeKind1) + mapper.Put(NodeKind2) + mapper.Put(EdgeKind1) + mapper.Put(EdgeKind2) return mapper } diff --git a/packages/go/cypher/models/pgsql/translate/update.go b/packages/go/cypher/models/pgsql/translate/update.go index 7566bf65d2..6f5fc89ce9 100644 --- a/packages/go/cypher/models/pgsql/translate/update.go +++ b/packages/go/cypher/models/pgsql/translate/update.go @@ -105,8 +105,8 @@ func (s *Translator) buildUpdates(scope *Scope) error { } if len(identifierMutation.KindAssignments) > 0 { - if kindIDs, missing := s.kindMapper.MapKinds(identifierMutation.KindAssignments); len(missing) > 0 { - return fmt.Errorf("unable to map kinds: %v", missing) + if kindIDs, err := s.kindMapper.MapKinds(s.ctx, identifierMutation.KindAssignments); err != nil { + s.SetError(fmt.Errorf("failed to translate kinds: %w", err)) } else { arrayLiteral := pgsql.ArrayLiteral{ Values: make([]pgsql.Expression, len(kindIDs)), @@ -122,8 +122,8 @@ func (s *Translator) buildUpdates(scope *Scope) error { } if len(identifierMutation.KindRemovals) > 0 { - if kindIDs, missing := s.kindMapper.MapKinds(identifierMutation.KindRemovals); len(missing) > 0 { - return fmt.Errorf("unable to map kinds: %v", missing) + if kindIDs, err := s.kindMapper.MapKinds(s.ctx, identifierMutation.KindRemovals); err != nil { + s.SetError(fmt.Errorf("failed to translate kinds: %w", err)) } else { arrayLiteral := pgsql.ArrayLiteral{ Values: make([]pgsql.Expression, len(kindIDs)), diff --git a/packages/go/cypher/models/pgsql/visualization/visualizer_test.go b/packages/go/cypher/models/pgsql/visualization/visualizer_test.go index 07687ac256..5ee79d401b 100644 --- a/packages/go/cypher/models/pgsql/visualization/visualizer_test.go +++ b/packages/go/cypher/models/pgsql/visualization/visualizer_test.go @@ -18,9 +18,10 @@ package visualization import ( "bytes" + "context" "testing" - "github.com/specterops/bloodhound/cypher/models/pgsql/test" + "github.com/specterops/bloodhound/dawgs/drivers/pg/pgutil" "github.com/specterops/bloodhound/cypher/frontend" "github.com/specterops/bloodhound/cypher/models/pgsql/translate" @@ -28,12 +29,12 @@ import ( ) func TestGraphToPUMLDigraph(t *testing.T) { - kindMapper := test.NewInMemoryKindMapper() + kindMapper := pgutil.NewInMemoryKindMapper() regularQuery, err := frontend.ParseCypher(frontend.NewContext(), "match (s), (e) where s.name = s.other + 1 / s.last return s") require.Nil(t, err) - translation, err := translate.Translate(regularQuery, kindMapper) + translation, err := translate.Translate(context.Background(), regularQuery, kindMapper, nil) require.Nil(t, err) graph, err := SQLToDigraph(translation.Statement) diff --git a/packages/go/dawgs/drivers/pg/batch.go b/packages/go/dawgs/drivers/pg/batch.go index bed59eac00..dfa3fb6872 100644 --- a/packages/go/dawgs/drivers/pg/batch.go +++ b/packages/go/dawgs/drivers/pg/batch.go @@ -162,8 +162,8 @@ func (s *batch) flushNodeCreateBufferWithIDs() error { for idx, nextNode := range s.nodeCreateBuffer { nodeIDs[idx] = nextNode.ID.Uint64() - if mappedKindIDs, missingKinds := s.schemaManager.MapKinds(nextNode.Kinds); len(missingKinds) > 0 { - return fmt.Errorf("unable to map kinds %v", missingKinds) + if mappedKindIDs, err := s.schemaManager.AssertKinds(s.ctx, nextNode.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) } else { kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) } @@ -196,8 +196,8 @@ func (s *batch) flushNodeCreateBufferWithoutIDs() error { ) for idx, nextNode := range s.nodeCreateBuffer { - if mappedKindIDs, missingKinds := s.schemaManager.MapKinds(nextNode.Kinds); len(missingKinds) > 0 { - return fmt.Errorf("unable to map kinds %v", missingKinds) + if mappedKindIDs, err := s.schemaManager.AssertKinds(s.ctx, nextNode.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) } else { kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) } @@ -222,7 +222,7 @@ func (s *batch) flushNodeCreateBufferWithoutIDs() error { func (s *batch) flushNodeUpsertBatch(updates *sql.NodeUpdateBatch) error { parameters := NewNodeUpsertParameters(len(updates.Updates)) - if err := parameters.AppendAll(updates, s.schemaManager, s.kindIDEncoder); err != nil { + if err := parameters.AppendAll(s.ctx, updates, s.schemaManager, s.kindIDEncoder); err != nil { return err } @@ -284,11 +284,11 @@ func (s *NodeUpsertParameters) Format(graphTarget model.Graph) []any { } } -func (s *NodeUpsertParameters) Append(update *sql.NodeUpdate, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { +func (s *NodeUpsertParameters) Append(ctx context.Context, update *sql.NodeUpdate, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { s.IDFutures = append(s.IDFutures, update.IDFuture) - if mappedKindIDs, missingKinds := schemaManager.MapKinds(update.Node.Kinds); len(missingKinds) > 0 { - return fmt.Errorf("unable to map kinds %v", missingKinds) + if mappedKindIDs, err := schemaManager.AssertKinds(ctx, update.Node.Kinds); err != nil { + return fmt.Errorf("unable to map kinds %w", err) } else { s.KindIDSlices = append(s.KindIDSlices, kindIDEncoder.Encode(mappedKindIDs)) } @@ -302,9 +302,9 @@ func (s *NodeUpsertParameters) Append(update *sql.NodeUpdate, schemaManager *Sch return nil } -func (s *NodeUpsertParameters) AppendAll(updates *sql.NodeUpdateBatch, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { +func (s *NodeUpsertParameters) AppendAll(ctx context.Context, updates *sql.NodeUpdateBatch, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { for _, nextUpdate := range updates.Updates { - if err := s.Append(nextUpdate, schemaManager, kindIDEncoder); err != nil { + if err := s.Append(ctx, nextUpdate, schemaManager, kindIDEncoder); err != nil { return err } } @@ -338,12 +338,12 @@ func (s *RelationshipUpdateByParameters) Format(graphTarget model.Graph) []any { } } -func (s *RelationshipUpdateByParameters) Append(update *sql.RelationshipUpdate, schemaManager *SchemaManager) error { +func (s *RelationshipUpdateByParameters) Append(ctx context.Context, update *sql.RelationshipUpdate, schemaManager *SchemaManager) error { s.StartIDs = append(s.StartIDs, update.StartID.Value) s.EndIDs = append(s.EndIDs, update.EndID.Value) - if mappedKindID, mapped := schemaManager.MapKind(update.Relationship.Kind); !mapped { - return fmt.Errorf("unable to map kind %s", update.Relationship.Kind) + if mappedKindID, err := schemaManager.MapKind(ctx, update.Relationship.Kind); err != nil { + return err } else { s.KindIDs = append(s.KindIDs, mappedKindID) } @@ -356,9 +356,9 @@ func (s *RelationshipUpdateByParameters) Append(update *sql.RelationshipUpdate, return nil } -func (s *RelationshipUpdateByParameters) AppendAll(updates *sql.RelationshipUpdateBatch, schemaManager *SchemaManager) error { +func (s *RelationshipUpdateByParameters) AppendAll(ctx context.Context, updates *sql.RelationshipUpdateBatch, schemaManager *SchemaManager) error { for _, nextUpdate := range updates.Updates { - if err := s.Append(nextUpdate, schemaManager); err != nil { + if err := s.Append(ctx, nextUpdate, schemaManager); err != nil { return err } } @@ -373,7 +373,7 @@ func (s *batch) flushRelationshipUpdateByBuffer(updates *sql.RelationshipUpdateB parameters := NewRelationshipUpdateByParameters(len(updates.Updates)) - if err := parameters.AppendAll(updates, s.schemaManager); err != nil { + if err := parameters.AppendAll(s.ctx, updates, s.schemaManager); err != nil { return err } @@ -454,7 +454,7 @@ func (s *relationshipCreateBatchBuilder) Build() (*relationshipCreateBatch, erro return s.relationshipUpdateBatch, s.relationshipUpdateBatch.EncodeProperties(s.edgePropertiesBatch) } -func (s *relationshipCreateBatchBuilder) Add(kindMapper KindMapper, edge *graph.Relationship) error { +func (s *relationshipCreateBatchBuilder) Add(ctx context.Context, kindMapper KindMapper, edge *graph.Relationship) error { keyBuilder := strings.Builder{} keyBuilder.WriteString(edge.StartID.String()) @@ -473,8 +473,8 @@ func (s *relationshipCreateBatchBuilder) Add(kindMapper KindMapper, edge *graph. edgeProperties = edge.Properties.Clone() ) - if edgeKindID, hasKind := kindMapper.MapKind(edge.Kind); !hasKind { - return fmt.Errorf("unable to map kind %s", edge.Kind) + if edgeKindID, err := kindMapper.MapKind(ctx, edge.Kind); err != nil { + return err } else { s.relationshipUpdateBatch.Add(startID, endID, edgeKindID) } @@ -492,7 +492,7 @@ func (s *batch) flushRelationshipCreateBuffer() error { batchBuilder := newRelationshipCreateBatchBuilder(len(s.relationshipCreateBuffer)) for _, nextRel := range s.relationshipCreateBuffer { - if err := batchBuilder.Add(s.schemaManager, nextRel); err != nil { + if err := batchBuilder.Add(s.ctx, s.schemaManager, nextRel); err != nil { return err } } diff --git a/packages/go/dawgs/drivers/pg/driver.go b/packages/go/dawgs/drivers/pg/driver.go index 1ef5831f51..6a8f0620f8 100644 --- a/packages/go/dawgs/drivers/pg/driver.go +++ b/packages/go/dawgs/drivers/pg/driver.go @@ -59,9 +59,7 @@ type Driver struct { } func (s *Driver) SetDefaultGraph(ctx context.Context, graphSchema graph.Graph) error { - return s.ReadTransaction(ctx, func(tx graph.Transaction) error { - return s.schemaManager.SetDefaultGraph(tx, graphSchema) - }) + return s.schemaManager.SetDefaultGraph(ctx, graphSchema) } func (s *Driver) KindMapper() KindMapper { @@ -176,19 +174,19 @@ func (s *Driver) FetchSchema(ctx context.Context) (graph.Schema, error) { } func (s *Driver) AssertSchema(ctx context.Context, schema graph.Schema) error { - if err := s.WriteTransaction(ctx, func(tx graph.Transaction) error { - if err := s.schemaManager.AssertSchema(tx, schema); err != nil { - return err - } else if schema.DefaultGraph.Name != "" { - return s.schemaManager.AssertDefaultGraph(tx, schema.DefaultGraph) - } + // Resetting the pool must be done on every schema assertion as composite types may have changed OIDs + defer s.pool.Reset() - return nil - }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)); err != nil { + // Assert that the base graph schema exists and has a matching schema definition + if err := s.schemaManager.AssertSchema(ctx, schema); err != nil { return err - } else { - // Resetting the pool must be done on every schema assertion as composite types may have changed OIDs - s.pool.Reset() + } + + if schema.DefaultGraph.Name != "" { + // There's a default graph defined. Assert that it exists and has a matching schema + if err := s.schemaManager.AssertDefaultGraph(ctx, schema.DefaultGraph); err != nil { + return err + } } return nil diff --git a/packages/go/dawgs/drivers/pg/manager.go b/packages/go/dawgs/drivers/pg/manager.go index f005f11c84..759167013a 100644 --- a/packages/go/dawgs/drivers/pg/manager.go +++ b/packages/go/dawgs/drivers/pg/manager.go @@ -17,7 +17,10 @@ package pg import ( + "context" "errors" + "fmt" + "strings" "sync" "github.com/jackc/pgx/v5" @@ -27,15 +30,16 @@ import ( ) type KindMapper interface { - MapKindID(kindID int16) (graph.Kind, bool) - MapKindIDs(kindIDs ...int16) (graph.Kinds, []int16) - MapKind(kind graph.Kind) (int16, bool) - MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) - AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([]int16, error) + MapKindID(ctx context.Context, kindID int16) (graph.Kind, error) + MapKindIDs(ctx context.Context, kindIDs ...int16) (graph.Kinds, error) + MapKind(ctx context.Context, kind graph.Kind) (int16, error) + MapKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) + AssertKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) } type SchemaManager struct { defaultGraph model.Graph + database graph.Database hasDefaultGraph bool graphs map[string]model.Graph kindsByID map[graph.Kind]int16 @@ -43,8 +47,9 @@ type SchemaManager struct { lock *sync.RWMutex } -func NewSchemaManager() *SchemaManager { +func NewSchemaManager(database graph.Database) *SchemaManager { return &SchemaManager{ + database: database, hasDefaultGraph: false, graphs: map[string]model.Graph{}, kindsByID: map[graph.Kind]int16{}, @@ -67,6 +72,12 @@ func (s *SchemaManager) fetch(tx graph.Transaction) error { return nil } +func (s *SchemaManager) Fetch(ctx context.Context) error { + return s.database.WriteTransaction(ctx, func(tx graph.Transaction) error { + return s.fetch(tx) + }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)) +} + func (s *SchemaManager) defineKinds(tx graph.Transaction, kinds graph.Kinds) error { for _, kind := range kinds { if kindID, err := query.On(tx).InsertOrGetKind(kind); err != nil { @@ -97,19 +108,50 @@ func (s *SchemaManager) mapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { return ids, missingKinds } -func (s *SchemaManager) MapKind(kind graph.Kind) (int16, bool) { +func (s *SchemaManager) MapKind(ctx context.Context, kind graph.Kind) (int16, error) { s.lock.RLock() - defer s.lock.RUnlock() - id, hasID := s.kindsByID[kind] - return id, hasID + if id, hasID := s.kindsByID[kind]; hasID { + s.lock.RUnlock() + return id, nil + } + + s.lock.RUnlock() + s.lock.Lock() + defer s.lock.Unlock() + + if err := s.Fetch(ctx); err != nil { + return -1, err + } + + if id, hasID := s.kindsByID[kind]; hasID { + return id, nil + } else { + return -1, fmt.Errorf("unable to map kind: %s", kind.String()) + } } -func (s *SchemaManager) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { +func (s *SchemaManager) MapKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) { s.lock.RLock() - defer s.lock.RUnlock() - return s.mapKinds(kinds) + if mappedKinds, missingKinds := s.mapKinds(kinds); len(missingKinds) == 0 { + s.lock.RUnlock() + return mappedKinds, nil + } + + s.lock.RUnlock() + s.lock.Lock() + defer s.lock.Unlock() + + if err := s.Fetch(ctx); err != nil { + return nil, err + } + + if mappedKinds, missingKinds := s.mapKinds(kinds); len(missingKinds) == 0 { + return mappedKinds, nil + } else { + return nil, fmt.Errorf("unable to map kinds: %s", strings.Join(missingKinds.Strings(), ", ")) + } } func (s *SchemaManager) mapKindIDs(kindIDs []int16) (graph.Kinds, []int16) { @@ -129,76 +171,109 @@ func (s *SchemaManager) mapKindIDs(kindIDs []int16) (graph.Kinds, []int16) { return kinds, missingIDs } -func (s *SchemaManager) MapKindID(kindID int16) (graph.Kind, bool) { - s.lock.RLock() - defer s.lock.RUnlock() - - kind, hasKind := s.kindIDsByKind[kindID] - return kind, hasKind -} - -func (s *SchemaManager) MapKindIDs(kindIDs ...int16) (graph.Kinds, []int16) { - s.lock.RLock() - defer s.lock.RUnlock() - - return s.mapKindIDs(kindIDs) +func (s *SchemaManager) MapKindID(ctx context.Context, kindID int16) (graph.Kind, error) { + if kindIDs, err := s.MapKindIDs(ctx, kindID); err != nil { + return nil, err + } else { + return kindIDs[0], nil + } } -func (s *SchemaManager) AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([]int16, error) { - // Acquire a read-lock first to fast-pass validate if we're missing any kind definitions +func (s *SchemaManager) MapKindIDs(ctx context.Context, kindIDs ...int16) (graph.Kinds, error) { s.lock.RLock() - if kindIDs, missingKinds := s.mapKinds(kinds); len(missingKinds) == 0 { - // All kinds are defined. Release the read-lock here before returning + if kinds, missingKinds := s.mapKindIDs(kindIDs); len(missingKinds) == 0 { s.lock.RUnlock() - return kindIDs, nil + return kinds, nil } - // Release the read-lock here so that we can acquire a write-lock s.lock.RUnlock() + s.lock.Lock() + defer s.lock.Unlock() + if err := s.Fetch(ctx); err != nil { + return nil, err + } + + if kinds, missingKinds := s.mapKindIDs(kindIDs); len(missingKinds) == 0 { + return kinds, nil + } else { + return nil, fmt.Errorf("unable to map kind ids: %v", missingKinds) + } +} + +func (s *SchemaManager) assertKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) { // Acquire a write-lock and release on-exit s.lock.Lock() defer s.lock.Unlock() // We have to re-acquire the missing kinds since there's a potential for another writer to acquire the write-lock // in between release of the read-lock and acquisition of the write-lock for this operation - _, missingKinds := s.mapKinds(kinds) - - if err := s.defineKinds(tx, missingKinds); err != nil { - return nil, err + if _, missingKinds := s.mapKinds(kinds); len(missingKinds) > 0 { + if err := s.database.WriteTransaction(ctx, func(tx graph.Transaction) error { + return s.defineKinds(tx, missingKinds) + }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)); err != nil { + return nil, err + } } + // Lookup the kinds again from memory as they should now be up to date kindIDs, _ := s.mapKinds(kinds) return kindIDs, nil } -func (s *SchemaManager) SetDefaultGraph(tx graph.Transaction, schema graph.Graph) error { - // Validate the schema if the graph already exists in the database - if definition, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { - return err - } else { - s.graphs[schema.Name] = definition +func (s *SchemaManager) AssertKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) { + // Acquire a read-lock first to fast-pass validate if we're missing any kind definitions + s.lock.RLock() - s.defaultGraph = definition - s.hasDefaultGraph = true + if kindIDs, missingKinds := s.mapKinds(kinds); len(missingKinds) == 0 { + // All kinds are defined. Release the read-lock here before returning + s.lock.RUnlock() + return kindIDs, nil } - return nil + // Release the read-lock here so that we can acquire a write-lock + s.lock.RUnlock() + return s.assertKinds(ctx, kinds) } -func (s *SchemaManager) AssertDefaultGraph(tx graph.Transaction, schema graph.Graph) error { - if graphInstance, err := s.AssertGraph(tx, schema); err != nil { - return err - } else { - s.lock.Lock() - defer s.lock.Unlock() +func (s *SchemaManager) setDefaultGraph(defaultGraph model.Graph, schema graph.Graph) { + s.lock.Lock() + defer s.lock.Unlock() - s.defaultGraph = graphInstance - s.hasDefaultGraph = true + if s.hasDefaultGraph { + // Another actor has already asserted or otherwise set a default graph + return } - return nil + s.graphs[schema.Name] = defaultGraph + + s.defaultGraph = defaultGraph + s.hasDefaultGraph = true +} + +func (s *SchemaManager) SetDefaultGraph(ctx context.Context, schema graph.Graph) error { + return s.database.ReadTransaction(ctx, func(tx graph.Transaction) error { + // Validate the schema if the graph already exists in the database + if graphModel, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { + return err + } else { + s.setDefaultGraph(graphModel, schema) + return nil + } + }) +} + +func (s *SchemaManager) AssertDefaultGraph(ctx context.Context, schema graph.Graph) error { + return s.database.WriteTransaction(ctx, func(tx graph.Transaction) error { + if graphModel, err := s.AssertGraph(tx, schema); err != nil { + return err + } else { + s.setDefaultGraph(graphModel, schema) + } + + return nil + }) } func (s *SchemaManager) DefaultGraph() (model.Graph, bool) { @@ -208,6 +283,33 @@ func (s *SchemaManager) DefaultGraph() (model.Graph, bool) { return s.defaultGraph, s.hasDefaultGraph } +func (s *SchemaManager) assertGraph(tx graph.Transaction, schema graph.Graph) (model.Graph, error) { + var assertedGraph model.Graph + + // Validate the schema if the graph already exists in the database + if definition, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { + // ErrNoRows is ignored as it signifies that this graph must be created + if !errors.Is(err, pgx.ErrNoRows) { + return model.Graph{}, err + } + + if newDefinition, err := query.On(tx).CreateGraph(schema); err != nil { + return model.Graph{}, err + } else { + assertedGraph = newDefinition + } + } else if assertedDefinition, err := query.On(tx).AssertGraph(schema, definition); err != nil { + return model.Graph{}, err + } else { + // Graph existed and may have been updated + assertedGraph = assertedDefinition + } + + // Cache the graph definition and return it + s.graphs[schema.Name] = assertedGraph + return assertedGraph, nil +} + func (s *SchemaManager) AssertGraph(tx graph.Transaction, schema graph.Graph) (model.Graph, error) { // Acquire a read-lock first to fast-pass validate if we're missing the graph definitions s.lock.RLock() @@ -218,44 +320,21 @@ func (s *SchemaManager) AssertGraph(tx graph.Transaction, schema graph.Graph) (m return graphInstance, nil } - // Release the read-lock here so that we can acquire a write-lock + // Release the read-lock here so that we can acquire a write-lock next s.lock.RUnlock() - // Acquire a write-lock and create the graph definition s.lock.Lock() defer s.lock.Unlock() if graphInstance, isDefined := s.graphs[schema.Name]; isDefined { - // The graph was defined by a different actor between the read unlock and the write lock. + // The graph was defined by a different actor between the read unlock and the write lock, return it return graphInstance, nil } - // Validate the schema if the graph already exists in the database - if definition, err := query.On(tx).SelectGraphByName(schema.Name); err != nil { - // ErrNoRows signifies that this graph must be created - if !errors.Is(err, pgx.ErrNoRows) { - return model.Graph{}, err - } - } else if assertedDefinition, err := query.On(tx).AssertGraph(schema, definition); err != nil { - return model.Graph{}, err - } else { - s.graphs[schema.Name] = assertedDefinition - return assertedDefinition, nil - } - - // Create the graph - if definition, err := query.On(tx).CreateGraph(schema); err != nil { - return model.Graph{}, err - } else { - s.graphs[schema.Name] = definition - return definition, nil - } + return s.assertGraph(tx, schema) } -func (s *SchemaManager) AssertSchema(tx graph.Transaction, schema graph.Schema) error { - s.lock.Lock() - defer s.lock.Unlock() - +func (s *SchemaManager) assertSchema(tx graph.Transaction, schema graph.Schema) error { if err := query.On(tx).CreateSchema(); err != nil { return err } @@ -280,3 +359,12 @@ func (s *SchemaManager) AssertSchema(tx graph.Transaction, schema graph.Schema) return nil } + +func (s *SchemaManager) AssertSchema(ctx context.Context, schema graph.Schema) error { + s.lock.Lock() + defer s.lock.Unlock() + + return s.database.WriteTransaction(ctx, func(tx graph.Transaction) error { + return s.assertSchema(tx, schema) + }, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)) +} diff --git a/packages/go/dawgs/drivers/pg/mapper.go b/packages/go/dawgs/drivers/pg/mapper.go index 061e1ec9b7..f4f7fdb05e 100644 --- a/packages/go/dawgs/drivers/pg/mapper.go +++ b/packages/go/dawgs/drivers/pg/mapper.go @@ -17,12 +17,13 @@ package pg import ( + "context" "fmt" "github.com/specterops/bloodhound/dawgs/graph" ) -func mapValue(kindMapper KindMapper) func(rawValue, target any) (bool, error) { +func mapValue(ctx context.Context, kindMapper KindMapper) func(rawValue, target any) (bool, error) { return func(rawValue, target any) (bool, error) { switch typedTarget := target.(type) { case *graph.Relationship: @@ -32,7 +33,7 @@ func mapValue(kindMapper KindMapper) func(rawValue, target any) (bool, error) { edge := edgeComposite{} if edge.TryMap(compositeMap) { - if err := edge.ToRelationship(kindMapper, typedTarget); err != nil { + if err := edge.ToRelationship(ctx, kindMapper, typedTarget); err != nil { return false, err } } else { @@ -47,7 +48,7 @@ func mapValue(kindMapper KindMapper) func(rawValue, target any) (bool, error) { node := nodeComposite{} if node.TryMap(compositeMap) { - if err := node.ToNode(kindMapper, typedTarget); err != nil { + if err := node.ToNode(ctx, kindMapper, typedTarget); err != nil { return false, err } } else { @@ -62,7 +63,7 @@ func mapValue(kindMapper KindMapper) func(rawValue, target any) (bool, error) { path := pathComposite{} if path.TryMap(compositeMap) { - if err := path.ToPath(kindMapper, typedTarget); err != nil { + if err := path.ToPath(ctx, kindMapper, typedTarget); err != nil { return false, err } } else { @@ -78,6 +79,6 @@ func mapValue(kindMapper KindMapper) func(rawValue, target any) (bool, error) { } } -func NewValueMapper(values []any, kindMapper KindMapper) graph.ValueMapper { - return graph.NewValueMapper(values, mapValue(kindMapper)) +func NewValueMapper(ctx context.Context, values []any, kindMapper KindMapper) graph.ValueMapper { + return graph.NewValueMapper(values, mapValue(ctx, kindMapper)) } diff --git a/packages/go/dawgs/drivers/pg/node_test.go b/packages/go/dawgs/drivers/pg/node_test.go index 146cf33665..e3d2b21ca2 100644 --- a/packages/go/dawgs/drivers/pg/node_test.go +++ b/packages/go/dawgs/drivers/pg/node_test.go @@ -20,6 +20,8 @@ import ( "context" "testing" + "github.com/specterops/bloodhound/dawgs/drivers/pg/pgutil" + "github.com/specterops/bloodhound/dawgs/graph" graph_mocks "github.com/specterops/bloodhound/dawgs/graph/mocks" "github.com/specterops/bloodhound/dawgs/query" @@ -27,58 +29,31 @@ import ( "go.uber.org/mock/gomock" ) -type testKindMapper struct { - known map[string]int16 -} - -func (s testKindMapper) MapKindID(kindID int16) (graph.Kind, bool) { - panic("implement me") -} - -func (s testKindMapper) MapKindIDs(kindIDs ...int16) (graph.Kinds, []int16) { - panic("implement me") -} - -func (s testKindMapper) MapKind(kind graph.Kind) (int16, bool) { - panic("implement me") -} - -func (s testKindMapper) AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([]int16, error) { - panic("implement me") -} +var ( + NodeKind1 = graph.StringKind("NodeKind1") + NodeKind2 = graph.StringKind("NodeKind2") + EdgeKind1 = graph.StringKind("EdgeKind1") + EdgeKind2 = graph.StringKind("EdgeKind2") +) -func (s testKindMapper) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { - var ( - kindIDs = make([]int16, 0, len(kinds)) - missingKinds = make([]graph.Kind, 0, len(kinds)) - ) +func newKindMapper() KindMapper { + mapper := pgutil.NewInMemoryKindMapper() - for _, kind := range kinds { - if kindID, hasKind := s.known[kind.String()]; hasKind { - kindIDs = append(kindIDs, kindID) - } else { - missingKinds = append(missingKinds, kind) - } - } + // This is here to make SQL output a little more predictable for test cases + mapper.Put(NodeKind1) + mapper.Put(NodeKind2) + mapper.Put(EdgeKind1) + mapper.Put(EdgeKind2) - return kindIDs, missingKinds + return mapper } func TestNodeQuery(t *testing.T) { var ( - mockCtrl = gomock.NewController(t) - mockTx = graph_mocks.NewMockTransaction(mockCtrl) - mockResult = graph_mocks.NewMockResult(mockCtrl) - - kindMapper = testKindMapper{ - known: map[string]int16{ - "NodeKindA": 1, - "NodeKindB": 2, - "EdgeKindA": 3, - "EdgeKindB": 4, - }, - } - + mockCtrl = gomock.NewController(t) + mockTx = graph_mocks.NewMockTransaction(mockCtrl) + mockResult = graph_mocks.NewMockResult(mockCtrl) + kindMapper = newKindMapper() nodeQueryInst = &nodeQuery{ liveQuery: newLiveQuery(context.Background(), mockTx, kindMapper), } diff --git a/packages/go/dawgs/drivers/pg/pg.go b/packages/go/dawgs/drivers/pg/pg.go index 6888d2a9c4..6e22ad64c9 100644 --- a/packages/go/dawgs/drivers/pg/pg.go +++ b/packages/go/dawgs/drivers/pg/pg.go @@ -87,12 +87,16 @@ func newDatabase(connectionString string) (*Driver, error) { if pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg); err != nil { return nil, err } else { - return &Driver{ + driverInst := &Driver{ pool: pool, - schemaManager: NewSchemaManager(), defaultTransactionTimeout: defaultTransactionTimeout, batchWriteSize: defaultBatchWriteSize, - }, nil + } + + // Because the schema manager will act on the database on its own it needs a reference to the driver + // TODO: This cyclical dependency might want to be unwound + driverInst.schemaManager = NewSchemaManager(driverInst) + return driverInst, nil } } } diff --git a/packages/go/dawgs/drivers/pg/pgutil/kindmapper.go b/packages/go/dawgs/drivers/pg/pgutil/kindmapper.go new file mode 100644 index 0000000000..66a2eb5c19 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/pgutil/kindmapper.go @@ -0,0 +1,112 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pgutil + +import ( + "context" + "fmt" + + "github.com/specterops/bloodhound/dawgs/graph" +) + +var nextKindID = int16(1) + +type InMemoryKindMapper struct { + KindToID map[graph.Kind]int16 + IDToKind map[int16]graph.Kind +} + +func NewInMemoryKindMapper() *InMemoryKindMapper { + return &InMemoryKindMapper{ + KindToID: map[graph.Kind]int16{}, + IDToKind: map[int16]graph.Kind{}, + } +} + +func (s *InMemoryKindMapper) MapKindID(ctx context.Context, kindID int16) (graph.Kind, error) { + if kind, hasKind := s.IDToKind[kindID]; hasKind { + return kind, nil + } + + return nil, fmt.Errorf("kind not found for id %d", kindID) +} + +func (s *InMemoryKindMapper) MapKindIDs(ctx context.Context, kindIDs ...int16) (graph.Kinds, error) { + kinds := make(graph.Kinds, len(kindIDs)) + + for idx, kindID := range kindIDs { + if kind, err := s.MapKindID(ctx, kindID); err != nil { + return nil, err + } else { + kinds[idx] = kind + } + } + + return kinds, nil +} + +func (s *InMemoryKindMapper) MapKind(ctx context.Context, kind graph.Kind) (int16, error) { + if id, hasID := s.KindToID[kind]; hasID { + return id, nil + } + + return 0, fmt.Errorf("no id found for kind %s", kind) +} + +func (s *InMemoryKindMapper) mapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { + var ( + ids = make([]int16, 0, len(kinds)) + missing = make(graph.Kinds, 0, len(kinds)) + ) + + for _, kind := range kinds { + if id, found := s.KindToID[kind]; !found { + missing = append(missing, kind) + } else { + ids = append(ids, id) + } + } + + return ids, missing +} +func (s *InMemoryKindMapper) MapKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) { + if ids, missing := s.mapKinds(kinds); len(missing) > 0 { + return nil, fmt.Errorf("missing kinds: %v", missing) + } else { + return ids, nil + } +} + +func (s *InMemoryKindMapper) AssertKinds(ctx context.Context, kinds graph.Kinds) ([]int16, error) { + ids, missing := s.mapKinds(kinds) + + for _, kind := range missing { + ids = append(ids, s.Put(kind)) + } + + return ids, nil +} + +func (s *InMemoryKindMapper) Put(kind graph.Kind) int16 { + kindID := nextKindID + nextKindID += 1 + + s.KindToID[kind] = kindID + s.IDToKind[kindID] = kind + + return kindID +} diff --git a/packages/go/dawgs/drivers/pg/query.go b/packages/go/dawgs/drivers/pg/query.go index 71d996ce6d..d8dee2f61a 100644 --- a/packages/go/dawgs/drivers/pg/query.go +++ b/packages/go/dawgs/drivers/pg/query.go @@ -43,7 +43,7 @@ func newLiveQuery(ctx context.Context, tx graph.Transaction, kindMapper KindMapp func (s *liveQuery) runRegularQuery(allShortestPaths bool) graph.Result { if regularQuery, err := s.queryBuilder.Build(allShortestPaths); err != nil { return graph.NewErrorResult(err) - } else if translation, err := translate.FromCypher(regularQuery, s.kindMapper, false); err != nil { + } else if translation, err := translate.FromCypher(s.ctx, regularQuery, s.kindMapper, false); err != nil { return graph.NewErrorResult(err) } else { return s.tx.Raw(translation.Statement, translation.Parameters) diff --git a/packages/go/dawgs/drivers/pg/result.go b/packages/go/dawgs/drivers/pg/result.go index 4d841dbaff..fa9acb9e02 100644 --- a/packages/go/dawgs/drivers/pg/result.go +++ b/packages/go/dawgs/drivers/pg/result.go @@ -17,13 +17,14 @@ package pg import ( - "fmt" + "context" "github.com/jackc/pgx/v5" "github.com/specterops/bloodhound/dawgs/graph" ) type queryResult struct { + ctx context.Context rows pgx.Rows kindMapper KindMapper } @@ -36,7 +37,7 @@ func (s *queryResult) Values() (graph.ValueMapper, error) { if values, err := s.rows.Values(); err != nil { return nil, err } else { - return NewValueMapper(values, s.kindMapper), nil + return NewValueMapper(s.ctx, values, s.kindMapper), nil } } @@ -72,24 +73,24 @@ func (s *queryResult) Scan(targets ...any) error { for idx, pgTarget := range pgTargets { switch typedPGTarget := pgTarget.(type) { case *pathComposite: - if err := typedPGTarget.ToPath(s.kindMapper, targets[idx].(*graph.Path)); err != nil { + if err := typedPGTarget.ToPath(s.ctx, s.kindMapper, targets[idx].(*graph.Path)); err != nil { return err } case *edgeComposite: - if err := typedPGTarget.ToRelationship(s.kindMapper, targets[idx].(*graph.Relationship)); err != nil { + if err := typedPGTarget.ToRelationship(s.ctx, s.kindMapper, targets[idx].(*graph.Relationship)); err != nil { return err } case *nodeComposite: - if err := typedPGTarget.ToNode(s.kindMapper, targets[idx].(*graph.Node)); err != nil { + if err := typedPGTarget.ToNode(s.ctx, s.kindMapper, targets[idx].(*graph.Node)); err != nil { return err } case *int16: if kindPtr, isKindType := targets[idx].(*graph.Kind); isKindType { - if kind, hasKind := s.kindMapper.MapKindID(*typedPGTarget); !hasKind { - return fmt.Errorf("unable to map kind ID %d", *typedPGTarget) + if kind, err := s.kindMapper.MapKindID(s.ctx, *typedPGTarget); err != nil { + return err } else { *kindPtr = kind } @@ -97,8 +98,8 @@ func (s *queryResult) Scan(targets ...any) error { case *[]int16: if kindsPtr, isKindsType := targets[idx].(*graph.Kinds); isKindsType { - if kinds, missingKindIDs := s.kindMapper.MapKindIDs(*typedPGTarget...); len(missingKindIDs) > 0 { - return fmt.Errorf("unable to map kind IDs %+v", missingKindIDs) + if kinds, err := s.kindMapper.MapKindIDs(s.ctx, *typedPGTarget...); err != nil { + return err } else { *kindsPtr = kinds } diff --git a/packages/go/dawgs/drivers/pg/transaction.go b/packages/go/dawgs/drivers/pg/transaction.go index af35861636..184f6cf146 100644 --- a/packages/go/dawgs/drivers/pg/transaction.go +++ b/packages/go/dawgs/drivers/pg/transaction.go @@ -131,7 +131,7 @@ func (s *transaction) getTargetGraph() (model.Graph, error) { func (s *transaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) { if graphTarget, err := s.getTargetGraph(); err != nil { return nil, err - } else if kindIDSlice, err := s.schemaManager.AssertKinds(s, kinds); err != nil { + } else if kindIDSlice, err := s.schemaManager.AssertKinds(s.ctx, kinds); err != nil { return nil, err } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(properties); err != nil { return nil, err @@ -192,7 +192,7 @@ func (s *transaction) Nodes() graph.NodeQuery { func (s *transaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { if graphTarget, err := s.getTargetGraph(); err != nil { return nil, err - } else if kindIDSlice, err := s.schemaManager.AssertKinds(s, graph.Kinds{kind}); err != nil { + } else if kindIDSlice, err := s.schemaManager.AssertKinds(s.ctx, graph.Kinds{kind}); err != nil { return nil, err } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(properties); err != nil { return nil, err @@ -280,7 +280,7 @@ func (s *transaction) query(query string, parameters map[string]any) (pgx.Rows, func (s *transaction) Query(query string, parameters map[string]any) graph.Result { if parsedQuery, err := frontend.ParseCypher(frontend.NewContext(), query); err != nil { return graph.NewErrorResult(err) - } else if translated, err := translate.Translate(parsedQuery, s.schemaManager); err != nil { + } else if translated, err := translate.Translate(s.ctx, parsedQuery, s.schemaManager, parameters); err != nil { return graph.NewErrorResult(err) } else if sqlQuery, err := translate.Translated(translated); err != nil { return graph.NewErrorResult(err) @@ -294,6 +294,7 @@ func (s *transaction) Raw(query string, parameters map[string]any) graph.Result return graph.NewErrorResult(err) } else { return &queryResult{ + ctx: s.ctx, rows: rows, kindMapper: s.schemaManager, } diff --git a/packages/go/dawgs/drivers/pg/types.go b/packages/go/dawgs/drivers/pg/types.go index 3ce3ae9f01..5e9d63c6fd 100644 --- a/packages/go/dawgs/drivers/pg/types.go +++ b/packages/go/dawgs/drivers/pg/types.go @@ -17,6 +17,7 @@ package pg import ( + "context" "fmt" "github.com/specterops/bloodhound/dawgs/graph" @@ -163,9 +164,9 @@ func (s *edgeComposite) FromMap(compositeMap map[string]any) error { return nil } -func (s *edgeComposite) ToRelationship(kindMapper KindMapper, relationship *graph.Relationship) error { - if kinds, missingIDs := kindMapper.MapKindIDs(s.KindID); len(missingIDs) > 0 { - return fmt.Errorf("edge references the following unknown kind IDs: %v", missingIDs) +func (s *edgeComposite) ToRelationship(ctx context.Context, kindMapper KindMapper, relationship *graph.Relationship) error { + if kinds, err := kindMapper.MapKindIDs(ctx, s.KindID); err != nil { + return err } else { relationship.Kind = kinds[0] } @@ -206,9 +207,9 @@ func (s *nodeComposite) FromMap(compositeMap map[string]any) error { return nil } -func (s *nodeComposite) ToNode(kindMapper KindMapper, node *graph.Node) error { - if kinds, missingIDs := kindMapper.MapKindIDs(s.KindIDs...); len(missingIDs) > 0 { - return fmt.Errorf("node references the following unknown kind IDs: %v", missingIDs) +func (s *nodeComposite) ToNode(ctx context.Context, kindMapper KindMapper, node *graph.Node) error { + if kinds, err := kindMapper.MapKindIDs(ctx, s.KindIDs...); err != nil { + return err } else { node.Kinds = kinds } @@ -276,13 +277,13 @@ func (s *pathComposite) FromMap(compositeMap map[string]any) error { return nil } -func (s *pathComposite) ToPath(kindMapper KindMapper, path *graph.Path) error { +func (s *pathComposite) ToPath(ctx context.Context, kindMapper KindMapper, path *graph.Path) error { path.Nodes = make([]*graph.Node, len(s.Nodes)) for idx, pgNode := range s.Nodes { dawgsNode := &graph.Node{} - if err := pgNode.ToNode(kindMapper, dawgsNode); err != nil { + if err := pgNode.ToNode(ctx, kindMapper, dawgsNode); err != nil { return err } @@ -294,7 +295,7 @@ func (s *pathComposite) ToPath(kindMapper KindMapper, path *graph.Path) error { for idx, pgEdge := range s.Edges { dawgsRelationship := &graph.Relationship{} - if err := pgEdge.ToRelationship(kindMapper, dawgsRelationship); err != nil { + if err := pgEdge.ToRelationship(ctx, kindMapper, dawgsRelationship); err != nil { return err }