From c9643fb41ddcb54e858c33fba6027afab0ecb436 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Fri, 17 Jan 2025 08:43:00 -0800 Subject: [PATCH] cysql fixes --- packages/go/cypher/frontend/query.go | 39 +- packages/go/cypher/models/cypher/functions.go | 1 + packages/go/cypher/models/cypher/model.go | 4 + .../go/cypher/models/pgsql/format/format.go | 5 +- packages/go/cypher/models/pgsql/functions.go | 10 + .../go/cypher/models/pgsql/identifiers.go | 4 + packages/go/cypher/models/pgsql/model.go | 27 +- packages/go/cypher/models/pgsql/pgtypes.go | 1 + .../test/translation_cases/multipart.sql | 188 ++++++ .../test/translation_cases/shortest_paths.sql | 21 + .../translation_cases/stepwise_traversal.sql | 11 + .../cypher/models/pgsql/translate/building.go | 69 +- .../cypher/models/pgsql/translate/delete.go | 8 +- .../models/pgsql/translate/expansion.go | 64 +- .../models/pgsql/translate/expression.go | 4 +- .../cypher/models/pgsql/translate/function.go | 237 +++++++ .../go/cypher/models/pgsql/translate/model.go | 102 ++- .../go/cypher/models/pgsql/translate/node.go | 41 +- .../cypher/models/pgsql/translate/pattern.go | 28 +- .../models/pgsql/translate/predicate.go | 36 +- .../models/pgsql/translate/projection.go | 54 +- .../go/cypher/models/pgsql/translate/query.go | 182 +++++ .../models/pgsql/translate/relationship.go | 25 +- .../cypher/models/pgsql/translate/renamer.go | 24 +- .../cypher/models/pgsql/translate/tracking.go | 73 +- .../models/pgsql/translate/tracking_test.go | 28 + .../models/pgsql/translate/translation.go | 68 +- .../models/pgsql/translate/translator.go | 626 ++++-------------- .../cypher/models/pgsql/translate/update.go | 8 +- packages/go/cypher/models/walk/walk.go | 6 +- packages/go/cypher/models/walk/walk_cypher.go | 8 +- packages/go/cypher/models/walk/walk_pgsql.go | 6 + packages/go/dawgs/drivers/pg/tooling.go | 2 - 33 files changed, 1287 insertions(+), 723 deletions(-) create mode 100644 packages/go/cypher/models/pgsql/test/translation_cases/multipart.sql create mode 100644 packages/go/cypher/models/pgsql/translate/function.go create mode 100644 packages/go/cypher/models/pgsql/translate/query.go create mode 100644 packages/go/cypher/models/pgsql/translate/tracking_test.go diff --git a/packages/go/cypher/frontend/query.go b/packages/go/cypher/frontend/query.go index 04699b8386..9789e3097a 100644 --- a/packages/go/cypher/frontend/query.go +++ b/packages/go/cypher/frontend/query.go @@ -353,44 +353,61 @@ func (s *SinglePartQueryVisitor) ExitOC_UpdatingClause(ctx *parser.OC_UpdatingCl type MultiPartQueryVisitor struct { BaseVisitor - Query *cypher.MultiPartQuery + Query *cypher.MultiPartQuery + partIdx int } func NewMultiPartQueryVisitor() *MultiPartQueryVisitor { return &MultiPartQueryVisitor{ - Query: cypher.NewMultiPartQuery(), + Query: cypher.NewMultiPartQuery(), + partIdx: 0, } } func (s *MultiPartQueryVisitor) EnterOC_ReadingClause(ctx *parser.OC_ReadingClauseContext) { + // If the part index is equal to the length of parts then this signifies that a new query part + // is required. We do not advance the index here - this is done with the following `with` + // cypher AST component + if len(s.Query.Parts) == s.partIdx { + s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart()) + } + s.ctx.Enter(NewReadingClauseVisitor()) } func (s *MultiPartQueryVisitor) ExitOC_ReadingClause(ctx *parser.OC_ReadingClauseContext) { - part := cypher.NewMultiPartQueryPart() - part.AddReadingClause(s.ctx.Exit().(*ReadingClauseVisitor).ReadingClause) - s.Query.Parts = append(s.Query.Parts, part) + s.Query.CurrentPart().AddReadingClause(s.ctx.Exit().(*ReadingClauseVisitor).ReadingClause) } func (s *MultiPartQueryVisitor) EnterOC_UpdatingClause(ctx *parser.OC_UpdatingClauseContext) { + if len(s.Query.Parts) == s.partIdx { + s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart()) + } + s.ctx.Enter(NewUpdatingClauseVisitor()) } func (s *MultiPartQueryVisitor) ExitOC_UpdatingClause(ctx *parser.OC_UpdatingClauseContext) { + // Make sure to mark that this multipart query part contains a mutation (non-read operation) s.ctx.HasMutation = true - part := cypher.NewMultiPartQueryPart() - part.AddUpdatingClause(s.ctx.Exit().(*UpdatingClauseVisitor).UpdatingClause) - s.Query.Parts = append(s.Query.Parts, part) + + s.Query.CurrentPart().AddUpdatingClause(s.ctx.Exit().(*UpdatingClauseVisitor).UpdatingClause) } func (s *MultiPartQueryVisitor) EnterOC_With(ctx *parser.OC_WithContext) { + if len(s.Query.Parts) == s.partIdx { + s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart()) + } + s.ctx.Enter(NewWithVisitor()) } func (s *MultiPartQueryVisitor) ExitOC_With(ctx *parser.OC_WithContext) { - part := cypher.NewMultiPartQueryPart() - part.With = s.ctx.Exit().(*WithVisitor).With - s.Query.Parts = append(s.Query.Parts, part) + s.Query.CurrentPart().With = s.ctx.Exit().(*WithVisitor).With + + // Advance the part index so a new multipart query part gets allocated for the next reading + // or updating clause + s.partIdx += 1 } func (s *MultiPartQueryVisitor) EnterOC_SinglePartQuery(ctx *parser.OC_SinglePartQueryContext) { diff --git a/packages/go/cypher/models/cypher/functions.go b/packages/go/cypher/models/cypher/functions.go index 59cd959c9d..1232e8d72f 100644 --- a/packages/go/cypher/models/cypher/functions.go +++ b/packages/go/cypher/models/cypher/functions.go @@ -34,6 +34,7 @@ const ( ToIntegerFunction = "toint" ListSizeFunction = "size" CoalesceFunction = "coalesce" + CollectFunction = "collect" // ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/) ITTCYear = "year" diff --git a/packages/go/cypher/models/cypher/model.go b/packages/go/cypher/models/cypher/model.go index 064177c356..10e39ddc77 100644 --- a/packages/go/cypher/models/cypher/model.go +++ b/packages/go/cypher/models/cypher/model.go @@ -260,6 +260,10 @@ type MultiPartQuery struct { SinglePartQuery *SinglePartQuery } +func (s *MultiPartQuery) CurrentPart() *MultiPartQueryPart { + return s.Parts[len(s.Parts)-1] +} + func NewMultiPartQuery() *MultiPartQuery { return &MultiPartQuery{} } diff --git a/packages/go/cypher/models/pgsql/format/format.go b/packages/go/cypher/models/pgsql/format/format.go index 86454300ac..db5293125c 100644 --- a/packages/go/cypher/models/pgsql/format/format.go +++ b/packages/go/cypher/models/pgsql/format/format.go @@ -492,7 +492,7 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error { } case pgsql.ExistsExpression: - exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Subquery, pgsql.FormattingLiteral("exists (")) + exprStack = append(exprStack, typedNextExpr.Subquery, pgsql.FormattingLiteral("exists ")) if typedNextExpr.Negated { exprStack = append(exprStack, pgsql.FormattingLiteral("not ")) @@ -517,6 +517,9 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error { } } + case pgsql.Subquery: + exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Query, pgsql.FormattingLiteral("(")) + default: return fmt.Errorf("unable to format pgsql node type: %T", nextExpr) } diff --git a/packages/go/cypher/models/pgsql/functions.go b/packages/go/cypher/models/pgsql/functions.go index be5267764a..d58abcf1a3 100644 --- a/packages/go/cypher/models/pgsql/functions.go +++ b/packages/go/cypher/models/pgsql/functions.go @@ -44,3 +44,13 @@ const ( FunctionEdgesToPath Identifier = "edges_to_path" FunctionExtract Identifier = "extract" ) + +func IsAggregateFunction(function Identifier) bool { + switch function { + case FunctionCount, FunctionArrayAggregate: + return true + + default: + return false + } +} diff --git a/packages/go/cypher/models/pgsql/identifiers.go b/packages/go/cypher/models/pgsql/identifiers.go index 582290bd93..68b6b2c7c6 100644 --- a/packages/go/cypher/models/pgsql/identifiers.go +++ b/packages/go/cypher/models/pgsql/identifiers.go @@ -73,6 +73,10 @@ func AsIdentifierSet(identifiers ...Identifier) *IdentifierSet { return newSet } +func (s *IdentifierSet) Clear() { + clear(s.identifiers) +} + func (s *IdentifierSet) Len() int { return len(s.identifiers) } diff --git a/packages/go/cypher/models/pgsql/model.go b/packages/go/cypher/models/pgsql/model.go index d5f3d06a40..627cd212da 100644 --- a/packages/go/cypher/models/pgsql/model.go +++ b/packages/go/cypher/models/pgsql/model.go @@ -228,6 +228,14 @@ type Subquery struct { Query Query } +func (s Subquery) NodeType() string { + return "subquery" +} + +func (s Subquery) AsExpression() Expression { + return s +} + // not type UnaryExpression struct { Operator Expression @@ -321,6 +329,10 @@ type Parenthetical struct { Expression Expression } +func (s Parenthetical) AsSelectItem() SelectItem { + return s +} + func (s Parenthetical) NodeType() string { return "parenthetical" } @@ -1022,7 +1034,6 @@ func (s Select) NodeType() string { // select 1 // union // select 2; - type SetOperation struct { Operator Operator LOperand SetExpression @@ -1047,7 +1058,7 @@ func (s SetOperation) NodeType() string { // // [not] exists() type ExistsExpression struct { - Subquery Query + Subquery Subquery Negated bool } @@ -1132,18 +1143,6 @@ func (s Query) NodeType() string { return "query" } -func BinaryExpressionJoinTyped(optional Expression, operator Operator, conjoined *BinaryExpression) *BinaryExpression { - if optional == nil { - return conjoined - } - - return NewBinaryExpression( - conjoined, - operator, - optional, - ) -} - func BinaryExpressionJoin(optional Expression, operator Operator, conjoined Expression) Expression { if optional == nil { return conjoined diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go index 1d726a1699..a2630efffd 100644 --- a/packages/go/cypher/models/pgsql/pgtypes.go +++ b/packages/go/cypher/models/pgsql/pgtypes.go @@ -110,6 +110,7 @@ const ( TimestampWithoutTimeZone DataType = "timestamp without time zone" Scope DataType = "scope" + InlineProjection DataType = "inline_projection" ParameterIdentifier DataType = "parameter_identifier" ExpansionPattern DataType = "expansion_pattern" ExpansionPath DataType = "expansion_path" diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/multipart.sql b/packages/go/cypher/models/pgsql/test/translation_cases/multipart.sql new file mode 100644 index 0000000000..07074538a7 --- /dev/null +++ b/packages/go/cypher/models/pgsql/test/translation_cases/multipart.sql @@ -0,0 +1,188 @@ +-- Copyright 2025 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- case: with '1' as target match (n:NodeKind1) where n.value = target return n +with s0 as (select '1' as i0), + s1 as (select s0.i0 as i0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from s0, + node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and n0.properties ->> 'value' = s0.i0) +select s1.n0 as n +from s1; + +-- case: match (n:NodeKind1) where n.value = 1 with n match (b) where id(b) = id(n) return b +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties ->> 'value')::int8 = 1) + select s1.n0 as n0 + from s1), + s2 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from s0, + node n1 + where n1.id = (s0.n0).id) +select s2.n1 as b +from s2; + +-- case: match (n:NodeKind1) where n.value = 1 with n match (f) where f.name = 'me' with f match (b) where id(b) = id(f) return b +with s0 as (with s1 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties ->> 'value')::int8 = 1) + select s1.n0 as n0 + from s1), + s2 as (with s3 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from node n1 + where n1.properties ->> 'name' = 'me') + select s3.n1 as n1 + from s3), + s4 as (select s2.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 + from s2, + node n2 + where n2.id = (s2.n1).id) +select s4.n2 as b +from s4; + +-- case: match (n:NodeKind1)-[:EdgeKind1*1..]->(:NodeKind2)-[:EdgeKind2]->(m:NodeKind1) where (n:NodeKind1 or n:NodeKind2) and n.enabled = true with m, collect(distinct(n)) as p where size(p) >= 10 return m +with s0 as (with s1 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, + e0.end_id, + 1, + n1.kind_ids operator (pg_catalog.&&) array [2]::int2[], + e0.start_id = e0.end_id, + array [e0.id] + from edge e0 + join node n0 + on + n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] and + (n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] or + n0.kind_ids operator (pg_catalog.&&) + array [2]::int2[]) and + (n0.properties ->> 'enabled')::bool = + true and + n0.id = + e0.start_id + join node n1 on n1.id = e0.end_id + where e0.kind_id = any (array [3]::int2[]) + union + select ex0.root_id, + e0.end_id, + ex0.depth + 1, + n1.kind_ids operator (pg_catalog.&&) array [2]::int2[], + e0.id = any (ex0.path), + ex0.path || e0.id + from ex0 + join edge e0 on e0.start_id = ex0.next_id + join node n1 on n1.id = e0.end_id + where ex0.depth < 5 + and not ex0.is_cycle + and e0.kind_id = any (array [3]::int2[])) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id + where ex0.satisfied), + s2 as (select s1.e0 as e0, + s1.ep0 as ep0, + s1.n0 as n0, + s1.n1 as n1, + (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, + (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 + from s1, + edge e1 + join node n2 + on n2.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n2.id = e1.end_id + where e1.kind_id = any (array [4]::int2[]) + and (s1.n1).id = e1.start_id) + select s2.n2 as n2, array_agg(distinct (n0))::nodecomposite[] as i0 + from s2 + group by n2) +select s0.n2 as m +from s0 +where array_length(s0.i0, 1)::int >= 10; + +-- case: match (n:NodeKind1)-[:EdgeKind1*1..]->(:NodeKind2)-[:EdgeKind2]->(m:NodeKind1) where (n:NodeKind1 or n:NodeKind2) and n.enabled = true with m, count(distinct(n)) as p where p >= 10 return m +with s0 as (with s1 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, + e0.end_id, + 1, + n1.kind_ids operator (pg_catalog.&&) array [2]::int2[], + e0.start_id = e0.end_id, + array [e0.id] + from edge e0 + join node n0 + on + n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] and + (n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] or + n0.kind_ids operator (pg_catalog.&&) + array [2]::int2[]) and + (n0.properties ->> 'enabled')::bool = + true and + n0.id = + e0.start_id + join node n1 on n1.id = e0.end_id + where e0.kind_id = any (array [3]::int2[]) + union + select ex0.root_id, + e0.end_id, + ex0.depth + 1, + n1.kind_ids operator (pg_catalog.&&) array [2]::int2[], + e0.id = any (ex0.path), + ex0.path || e0.id + from ex0 + join edge e0 on e0.start_id = ex0.next_id + join node n1 on n1.id = e0.end_id + where ex0.depth < 5 + and not ex0.is_cycle + and e0.kind_id = any (array [3]::int2[])) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id + where ex0.satisfied), + s2 as (select s1.e0 as e0, + s1.ep0 as ep0, + s1.n0 as n0, + s1.n1 as n1, + (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, + (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 + from s1, + edge e1 + join node n2 + on n2.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n2.id = e1.end_id + where e1.kind_id = any (array [4]::int2[]) + and (s1.n1).id = e1.start_id) + select s2.n2 as n2, count((n0))::int8 as i0 + from s2 + group by n2) +select s0.n2 as m +from s0 +where s0.i0 >= 10; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql index d8fc95bc8f..e6e5c459ba 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql @@ -71,3 +71,24 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) select edges_to_path(variadic ep0)::pathcomposite as p from s0 limit 1000; + +-- case: match p=shortestPath((n:NodeKind1)-[:EdgeKind1*1..]->(m)) where 'admin_tier_0' in split(m.system_tags, ' ') and n.objectid ends with '-513' and m<>n return p limit 1000 +-- cypher_params: {} +-- pgsql_params: {"pi0":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select e0.start_id, e0.end_id, 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.start_id = e0.end_id, array [e0.id] from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.properties ->> 'objectid' like '%-513' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [3]::int2[]);", "pi1":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select ex0.root_id, e0.end_id, ex0.depth + 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.id = any (ex0.path), ex0.path || e0.id from pathspace ex0 join edge e0 on e0.start_id = ex0.next_id join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle and e0.kind_id = any (array [3]::int2[]);"} +with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) + as (select * from asp_harness(@pi0::text, @pi1::text, 5)) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id + where ex0.satisfied + and n1.id <> n0.id) +select edges_to_path(variadic ep0)::pathcomposite as p +from s0 +limit 1000; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql index 86347ee1ef..db5bcbfafd 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql @@ -315,6 +315,17 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite select s0.n1 as n2 from s0; +-- case: match (n1)-[]->(n2) where n2 <> n1 return n2 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.id = e0.start_id + join node n1 on n1.id = e0.end_id + where n1.id <> n0.id) +select s0.n1 as n2 +from s0; + -- case: match ()-[r]->()-[e]->(n) where r <> e return n with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, diff --git a/packages/go/cypher/models/pgsql/translate/building.go b/packages/go/cypher/models/pgsql/translate/building.go index 3d2bc6af01..377cbbfdc8 100644 --- a/packages/go/cypher/models/pgsql/translate/building.go +++ b/packages/go/cypher/models/pgsql/translate/building.go @@ -20,54 +20,95 @@ import ( "github.com/specterops/bloodhound/cypher/models/pgsql" ) -func (s *Translator) buildProjection(scope *Scope) error { +func (s *Translator) buildInlineProjection(part *QueryPart) (pgsql.Select, error) { + var sqlSelect pgsql.Select + + if part.projections.Frame != nil { + sqlSelect.From = []pgsql.FromClause{{ + Source: part.projections.Frame.Binding.Identifier, + }} + } + + if projectionConstraint, err := s.treeTranslator.ConsumeAll(); err != nil { + return sqlSelect, err + } else { + sqlSelect.Where = projectionConstraint.Expression + } + + for _, projection := range part.projections.Items { + builtProjection := projection.SelectItem + + if projection.Alias.Set { + builtProjection = &pgsql.AliasedExpression{ + Expression: builtProjection, + Alias: projection.Alias, + } + } + + sqlSelect.Projection = append(sqlSelect.Projection, builtProjection) + } + + if len(part.projections.GroupBy) > 0 { + for _, groupBy := range part.projections.GroupBy { + sqlSelect.GroupBy = append(sqlSelect.GroupBy, groupBy) + } + } + + return sqlSelect, nil +} + +func (s *Translator) buildTailProjection() error { var ( + currentPart = s.query.CurrentPart() + currentFrame = s.query.Scope.CurrentFrame() singlePartQuerySelect = pgsql.Select{} ) singlePartQuerySelect.From = []pgsql.FromClause{{ Source: pgsql.TableReference{ - Name: pgsql.CompoundIdentifier{scope.CurrentFrameBinding().Identifier}, + Name: pgsql.CompoundIdentifier{currentFrame.Binding.Identifier}, }, }} if projectionConstraint, err := s.treeTranslator.ConsumeAll(); err != nil { return err - } else if projection, err := buildExternalProjection(scope, s.projections.Projections); err != nil { + } else if projection, err := buildExternalProjection(s.query.Scope, currentPart.projections.Items); err != nil { + return err + } else if err := RewriteExpressionIdentifiers(projectionConstraint.Expression, currentFrame.Binding.Identifier, nil); err != nil { return err } else { singlePartQuerySelect.Projection = projection singlePartQuerySelect.Where = projectionConstraint.Expression } - s.query.Model.Body = singlePartQuerySelect + currentPart.Model.Body = singlePartQuerySelect - if s.query.Skip.Set { - s.query.Model.Offset = s.query.Skip + if currentPart.Skip.Set { + currentPart.Model.Offset = currentPart.Skip } - if s.query.Limit.Set { - s.query.Model.Limit = s.query.Limit + if currentPart.Limit.Set { + currentPart.Model.Limit = currentPart.Limit } - if len(s.query.OrderBy) > 0 { - s.query.Model.OrderBy = s.query.OrderBy + if len(currentPart.OrderBy) > 0 { + currentPart.Model.OrderBy = currentPart.OrderBy } return nil } -func (s *Translator) buildMatch(scope *Scope) error { - for _, part := range s.match.Pattern.Parts { +func (s *Translator) buildMatch() error { + for _, part := range s.query.CurrentPart().match.Pattern.Parts { // Pattern can't be in scope at time of select as the pattern's scope directly depends on the // pattern parts - if err := s.buildPatternPart(scope, part); err != nil { + if err := s.buildPatternPart(part); err != nil { return err } // Declare the pattern variable in scope if set if part.PatternBinding.Set { - scope.Declare(part.PatternBinding.Value.Identifier) + s.query.Scope.Declare(part.PatternBinding.Value.Identifier) } } diff --git a/packages/go/cypher/models/pgsql/translate/delete.go b/packages/go/cypher/models/pgsql/translate/delete.go index 09fc46067b..6475e77a44 100644 --- a/packages/go/cypher/models/pgsql/translate/delete.go +++ b/packages/go/cypher/models/pgsql/translate/delete.go @@ -34,7 +34,7 @@ func (s *Translator) translateDelete(scope *Scope, cypherDelete *cypher.Delete) if deleteFrame, err := scope.PushFrame(); err != nil { return err } else { - if identifierDeletion, err := s.mutations.AddDeletion(scope, typedExpression, deleteFrame); err != nil { + if identifierDeletion, err := s.query.CurrentPart().mutations.AddDeletion(scope, typedExpression, deleteFrame); err != nil { return err } else if boundProjections, err := buildVisibleScopeProjections(scope, nil); err != nil { return err @@ -46,7 +46,7 @@ func (s *Translator) translateDelete(scope *Scope, cypherDelete *cypher.Delete) return fmt.Errorf("expected aliased expression to have an alias set") } else if typedProjection.Alias.Value == typedExpression { // This is the projection being replaced by the assignment - if rewrittenProjections, err := buildProjection(typedExpression, identifierDeletion.UpdateBinding, scope); err != nil { + if rewrittenProjections, err := buildProjection(typedExpression, identifierDeletion.UpdateBinding, scope, scope.ReferenceFrame()); err != nil { return err } else { identifierDeletion.Projection = append(identifierDeletion.Projection, rewrittenProjections...) @@ -73,7 +73,7 @@ func (s *Translator) translateDelete(scope *Scope, cypherDelete *cypher.Delete) } func (s *Translator) buildDeletions(scope *Scope) error { - for _, identifierDeletion := range s.mutations.Deletions.Values() { + for _, identifierDeletion := range s.query.CurrentPart().mutations.Deletions.Values() { var ( sqlDelete = pgsql.Delete{ Using: []pgsql.FromClause{{ @@ -117,7 +117,7 @@ func (s *Translator) buildDeletions(scope *Scope) error { sqlDelete.Returning = identifierDeletion.Projection sqlDelete.Where = models.ValueOptional(joinConstraint.Expression) - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: scope.CurrentFrameBinding().Identifier, }, diff --git a/packages/go/cypher/models/pgsql/translate/expansion.go b/packages/go/cypher/models/pgsql/translate/expansion.go index e819d7cb07..dccd4535e7 100644 --- a/packages/go/cypher/models/pgsql/translate/expansion.go +++ b/packages/go/cypher/models/pgsql/translate/expansion.go @@ -433,22 +433,24 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave pgsql.CompoundIdentifier{traversalStep.Edge.Identifier, pgsql.ColumnEndID}, pgsql.NewLiteral(1, pgsql.Int), pgsql.ExistsExpression{ - Subquery: pgsql.Query{ - Body: pgsql.Select{ - Projection: []pgsql.SelectItem{ - pgsql.NewLiteral(1, pgsql.Int), - }, - From: []pgsql.FromClause{{ - Source: pgsql.TableReference{ - Name: pgsql.CompoundIdentifier{model.EdgeTable}, - Binding: models.ValueOptional(traversalStep.Edge.Identifier), + Subquery: pgsql.Subquery{ + Query: pgsql.Query{ + Body: pgsql.Select{ + Projection: []pgsql.SelectItem{ + pgsql.NewLiteral(1, pgsql.Int), }, - }}, - Where: pgsql.NewBinaryExpression( - pgsql.CompoundIdentifier{traversalStep.RightNode.Identifier, pgsql.ColumnID}, - pgsql.OperatorEquals, - pgsql.CompoundIdentifier{traversalStep.Edge.Identifier, pgsql.ColumnStartID}, - ), + From: []pgsql.FromClause{{ + Source: pgsql.TableReference{ + Name: pgsql.CompoundIdentifier{model.EdgeTable}, + Binding: models.ValueOptional(traversalStep.Edge.Identifier), + }, + }}, + Where: pgsql.NewBinaryExpression( + pgsql.CompoundIdentifier{traversalStep.RightNode.Identifier, pgsql.ColumnID}, + pgsql.OperatorEquals, + pgsql.CompoundIdentifier{traversalStep.Edge.Identifier, pgsql.ColumnStartID}, + ), + }, }, }, Negated: false, @@ -474,22 +476,24 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave pgsql.NewLiteral(1, pgsql.Int), ), pgsql.ExistsExpression{ - Subquery: pgsql.Query{ - Body: pgsql.Select{ - Projection: []pgsql.SelectItem{ - pgsql.NewLiteral(1, pgsql.Int), - }, - From: []pgsql.FromClause{{ - Source: pgsql.TableReference{ - Name: pgsql.CompoundIdentifier{model.EdgeTable}, - Binding: models.ValueOptional(traversalStep.Edge.Identifier), + Subquery: pgsql.Subquery{ + Query: pgsql.Query{ + Body: pgsql.Select{ + Projection: []pgsql.SelectItem{ + pgsql.NewLiteral(1, pgsql.Int), }, - }}, - Where: pgsql.NewBinaryExpression( - pgsql.CompoundIdentifier{traversalStep.RightNode.Identifier, pgsql.ColumnID}, - pgsql.OperatorEquals, - pgsql.CompoundIdentifier{traversalStep.Edge.Identifier, pgsql.ColumnStartID}, - ), + From: []pgsql.FromClause{{ + Source: pgsql.TableReference{ + Name: pgsql.CompoundIdentifier{model.EdgeTable}, + Binding: models.ValueOptional(traversalStep.Edge.Identifier), + }, + }}, + Where: pgsql.NewBinaryExpression( + pgsql.CompoundIdentifier{traversalStep.RightNode.Identifier, pgsql.ColumnID}, + pgsql.OperatorEquals, + pgsql.CompoundIdentifier{traversalStep.Edge.Identifier, pgsql.ColumnStartID}, + ), + }, }, }, Negated: false, diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go index a76020933d..28f7cdb23c 100644 --- a/packages/go/cypher/models/pgsql/translate/expression.go +++ b/packages/go/cypher/models/pgsql/translate/expression.go @@ -719,7 +719,7 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato return fmt.Errorf("unknown identifier %s", typedROperand) } else { switch boundLOperand.DataType { - case pgsql.NodeComposite: + case pgsql.NodeComposite, pgsql.ExpansionRootNode, pgsql.ExpansionTerminalNode: switch boundROperand.DataType { case pgsql.NodeComposite, pgsql.ExpansionRootNode, pgsql.ExpansionTerminalNode: default: @@ -730,7 +730,7 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato newExpression.LOperand = pgsql.CompoundIdentifier{typedLOperand, pgsql.ColumnID} newExpression.ROperand = pgsql.CompoundIdentifier{typedROperand, pgsql.ColumnID} - case pgsql.EdgeComposite: + case pgsql.EdgeComposite, pgsql.ExpansionEdge: switch boundROperand.DataType { case pgsql.EdgeComposite, pgsql.ExpansionEdge: default: diff --git a/packages/go/cypher/models/pgsql/translate/function.go b/packages/go/cypher/models/pgsql/translate/function.go new file mode 100644 index 0000000000..f833cc0982 --- /dev/null +++ b/packages/go/cypher/models/pgsql/translate/function.go @@ -0,0 +1,237 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package translate + +import ( + "fmt" + "github.com/specterops/bloodhound/cypher/models/cypher" + "github.com/specterops/bloodhound/cypher/models/pgsql" + "strings" +) + +func (s *Translator) translateFunction(typedExpression *cypher.FunctionInvocation) { + switch formattedName := strings.ToLower(typedExpression.Name); formattedName { + case cypher.IdentityFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if referenceArgument, err := PopFromBuilderAs[pgsql.Identifier](s.treeTranslator); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.CompoundIdentifier{referenceArgument, pgsql.ColumnID}) + } + + case cypher.LocalTimeFunction: + if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.TimeWithoutTimeZone); err != nil { + s.SetError(err) + } + + case cypher.LocalDateTimeFunction: + if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.TimestampWithoutTimeZone); err != nil { + s.SetError(err) + } + + case cypher.DateFunction: + if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.Date); err != nil { + s.SetError(err) + } + + case cypher.DateTimeFunction: + if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.TimestampWithTimeZone); err != nil { + s.SetError(err) + } + + case cypher.EdgeTypeFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else if identifier, isIdentifier := argument.(pgsql.Identifier); !isIdentifier { + s.SetErrorf("expected an identifier for the cypher function: %s but received %T", typedExpression.Name, argument) + } else { + s.treeTranslator.Push(pgsql.CompoundIdentifier{identifier, pgsql.ColumnKindID}) + } + + case cypher.NodeLabelsFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else if identifier, isIdentifier := argument.(pgsql.Identifier); !isIdentifier { + s.SetErrorf("expected an identifier for the cypher function: %s but received %T", typedExpression.Name, argument) + } else { + s.treeTranslator.Push(pgsql.CompoundIdentifier{identifier, pgsql.ColumnKindIDs}) + } + + case cypher.CountFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionCount, + Parameters: []pgsql.Expression{argument}, + CastType: pgsql.Int8, + }) + } + + case cypher.StringSplitToArrayFunction: + if typedExpression.NumArguments() != 2 { + s.SetError(fmt.Errorf("expected two arguments for cypher function %s", typedExpression.Name)) + } else if delimiter, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else if splitReference, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + if _, hasHint := GetTypeHint(splitReference); !hasHint { + // Do our best to coerce the type into text + if typedSplitRef, err := TypeCastExpression(splitReference, pgsql.Text); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionStringToArray, + Parameters: []pgsql.Expression{typedSplitRef, delimiter}, + CastType: pgsql.TextArray, + }) + } + } else { + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionStringToArray, + Parameters: []pgsql.Expression{splitReference, delimiter}, + CastType: pgsql.TextArray, + }) + } + } + + case cypher.ToLowerFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + // Rewrite the property lookup operator with a JSON text field lookup + propertyLookup.Operator = pgsql.OperatorJSONTextField + } + + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionToLower, + Parameters: []pgsql.Expression{argument}, + CastType: pgsql.Text, + }) + } + + case cypher.ListSizeFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + var functionCall pgsql.FunctionCall + + if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + // Ensure that the JSONB array length function receives the JSONB type + propertyLookup.Operator = pgsql.OperatorJSONField + + functionCall = pgsql.FunctionCall{ + Function: pgsql.FunctionJSONBArrayLength, + Parameters: []pgsql.Expression{argument}, + CastType: pgsql.Int, + } + } else { + functionCall = pgsql.FunctionCall{ + Function: pgsql.FunctionArrayLength, + Parameters: []pgsql.Expression{argument, pgsql.NewLiteral(1, pgsql.Int)}, + CastType: pgsql.Int, + } + } + + s.treeTranslator.Push(functionCall) + } + + case cypher.ToUpperFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + // Rewrite the property lookup operator with a JSON text field lookup + propertyLookup.Operator = pgsql.OperatorJSONTextField + } + + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionToUpper, + Parameters: []pgsql.Expression{argument}, + CastType: pgsql.Text, + }) + } + + case cypher.ToStringFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Text)) + } + + case cypher.ToIntegerFunction: + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Int8)) + } + + case cypher.CoalesceFunction: + if err := s.translateCoalesceFunction(typedExpression); err != nil { + s.SetError(err) + } + + case cypher.CollectFunction: + // TODO: This causes an implicit group by of all other projected arguments + if typedExpression.NumArguments() != 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + switch typedArgument := unwrapParenthetical(argument).(type) { + case pgsql.Identifier: + if binding, bound := s.query.Scope.Lookup(typedArgument); !bound { + s.SetError(fmt.Errorf("binding not found for collect function argument %s", typedExpression.Name)) + } else if bindingArrayType, err := binding.DataType.ToArrayType(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionArrayAggregate, + Parameters: []pgsql.Expression{argument}, + Distinct: typedExpression.Distinct, + CastType: bindingArrayType, + }) + } + + default: + s.SetError(fmt.Errorf("expected identifier for cypher function: %s", typedExpression.Name)) + } + } + + default: + s.SetErrorf("unknown cypher function: %s", typedExpression.Name) + } +} diff --git a/packages/go/cypher/models/pgsql/translate/model.go b/packages/go/cypher/models/pgsql/translate/model.go index 095691d80f..e258c5a95e 100644 --- a/packages/go/cypher/models/pgsql/translate/model.go +++ b/packages/go/cypher/models/pgsql/translate/model.go @@ -151,15 +151,93 @@ func (s *Pattern) CurrentPart() *PatternPart { } type Query struct { + Parts []*QueryPart + Scope *Scope +} + +func (s *Query) HasParts() bool { + return len(s.Parts) > 0 +} + +func (s *Query) CurrentPart() *QueryPart { + return s.Parts[len(s.Parts)-1] +} + +func (s *Query) PreparePart(allocateFrame bool) error { + newPart := &QueryPart{ + Model: &pgsql.Query{ + CommonTableExpressions: &pgsql.With{}, + }, + } + + if allocateFrame { + if frame, err := s.Scope.PushFrame(); err != nil { + return err + } else { + newPart.frame = frame + } + } + + s.Parts = append(s.Parts, newPart) + return nil +} + +type QueryPart struct { Model *pgsql.Query - Scope *Scope Updates []*Mutations OrderBy []pgsql.OrderBy Skip models.Optional[pgsql.Expression] Limit models.Optional[pgsql.Expression] + + frame *Frame + properties map[string]pgsql.Expression + pattern *Pattern + match *Match + projections *Projections + mutations *Mutations +} + +func (s *QueryPart) HasProjections() bool { + return s.projections != nil && len(s.projections.Items) > 0 } -func (s *Query) CurrentOrderBy() *pgsql.OrderBy { +func (s *QueryPart) PrepareProjections(distinct bool) { + s.projections = &Projections{ + Distinct: distinct, + } +} + +func (s *QueryPart) PrepareMutations() { + if s.mutations == nil { + s.mutations = NewMutations() + } +} + +func (s *QueryPart) HasMutations() bool { + return s.mutations != nil && s.mutations.Assignments.Len() > 0 +} + +func (s *QueryPart) HasDeletions() bool { + return s.mutations != nil && s.mutations.Deletions.Len() > 0 +} + +func (s *QueryPart) PrepareProjection() { + s.projections.Items = append(s.projections.Items, &Projection{}) +} + +func (s *QueryPart) CurrentProjection() *Projection { + return s.projections.Current() +} + +func (s *QueryPart) PrepareProperties() { + if s.properties != nil { + clear(s.properties) + } else { + s.properties = map[string]pgsql.Expression{} + } +} + +func (s *QueryPart) CurrentOrderBy() *pgsql.OrderBy { return &s.OrderBy[len(s.OrderBy)-1] } @@ -312,21 +390,19 @@ func (s *Mutations) AddKindRemoval(scope *Scope, targetIdentifier pgsql.Identifi return nil } -type ProjectionClause struct { - Distinct bool - Projections []*Projection -} - -func NewProjectionClause() *ProjectionClause { - return &ProjectionClause{} +type Projections struct { + Distinct bool + Frame *Frame + Items []*Projection + GroupBy []pgsql.SelectItem } -func (s *ProjectionClause) PushProjection() { - s.Projections = append(s.Projections, &Projection{}) +func (s *Projections) Add(projection *Projection) { + s.Items = append(s.Items, projection) } -func (s *ProjectionClause) CurrentProjection() *Projection { - return s.Projections[len(s.Projections)-1] +func (s *Projections) Current() *Projection { + return s.Items[len(s.Items)-1] } func extractIdentifierFromCypherExpression(expression cypher.Expression) (pgsql.Identifier, bool, error) { diff --git a/packages/go/cypher/models/pgsql/translate/node.go b/packages/go/cypher/models/pgsql/translate/node.go index 19ae8a86e8..cdabce38bb 100644 --- a/packages/go/cypher/models/pgsql/translate/node.go +++ b/packages/go/cypher/models/pgsql/translate/node.go @@ -24,16 +24,21 @@ import ( "github.com/specterops/bloodhound/cypher/models/pgsql" ) -func (s *Translator) translateNodePattern(scope *Scope, nodePattern *cypher.NodePattern, part *PatternPart) error { - if bindingResult, err := s.bindPatternExpression(scope, nodePattern, pgsql.NodeComposite); err != nil { +func (s *Translator) translateNodePattern(nodePattern *cypher.NodePattern) error { + var ( + queryPart = s.query.CurrentPart() + patternPart = queryPart.pattern.CurrentPart() + ) + + if bindingResult, err := s.bindPatternExpression(nodePattern, pgsql.NodeComposite); err != nil { return err - } else if err := s.translateNodePatternToStep(scope, part, bindingResult); err != nil { + } else if err := s.translateNodePatternToStep(patternPart, bindingResult); err != nil { return err } else { - if len(s.properties) > 0 { + if len(queryPart.properties) > 0 { var propertyConstraints pgsql.Expression - for key, value := range s.properties { + for key, value := range queryPart.properties { propertyConstraints = pgsql.OptionalAnd(propertyConstraints, pgsql.NewBinaryExpression( pgsql.NewPropertyLookup(pgsql.CompoundIdentifier{bindingResult.Binding.Identifier, pgsql.ColumnProperties}, pgsql.NewLiteral(key, pgsql.Text)), pgsql.OperatorEquals, @@ -41,7 +46,7 @@ func (s *Translator) translateNodePattern(scope *Scope, nodePattern *cypher.Node )) } - if err := part.Constraints.Constrain(pgsql.AsIdentifierSet(bindingResult.Binding.Identifier), propertyConstraints); err != nil { + if err := patternPart.Constraints.Constrain(pgsql.AsIdentifierSet(bindingResult.Binding.Identifier), propertyConstraints); err != nil { return err } } @@ -58,7 +63,7 @@ func (s *Translator) translateNodePattern(scope *Scope, nodePattern *cypher.Node kindIDsLiteral, ) - if err := part.Constraints.Constrain(pgsql.AsIdentifierSet(bindingResult.Binding.Identifier), expression); err != nil { + if err := patternPart.Constraints.Constrain(pgsql.AsIdentifierSet(bindingResult.Binding.Identifier), expression); err != nil { return err } } @@ -68,7 +73,7 @@ func (s *Translator) translateNodePattern(scope *Scope, nodePattern *cypher.Node return nil } -func (s *Translator) translateNodePatternToStep(scope *Scope, part *PatternPart, bindingResult BindingResult) error { +func (s *Translator) translateNodePatternToStep(part *PatternPart, bindingResult BindingResult) error { if part.IsTraversal { if numSteps := len(part.TraversalSteps); numSteps == 0 { // This is the traversal step's left node @@ -96,13 +101,13 @@ func (s *Translator) translateNodePatternToStep(scope *Scope, part *PatternPart, // This is part of a continuing pattern element chain. Inspect the previous edge pattern to see if this // is the terminal node of an expansion. if currentStep.Expansion.Set { - if stepFrame, err := scope.PushFrame(); err != nil { + if stepFrame, err := s.query.Scope.PushFrame(); err != nil { return err } else { currentStep.Expansion.Value.Frame = stepFrame } - if boundProjections, err := buildVisibleScopeProjections(scope, currentStep.Definitions); err != nil { + if boundProjections, err := buildVisibleScopeProjections(s.query.Scope, currentStep.Definitions); err != nil { return err } else { currentStep.Expansion.Value.Projection = boundProjections.Items @@ -117,13 +122,13 @@ func (s *Translator) translateNodePatternToStep(scope *Scope, part *PatternPart, bindingResult.Binding.Link(expansionBinding) } } else { - if stepFrame, err := scope.PushFrame(); err != nil { + if stepFrame, err := s.query.Scope.PushFrame(); err != nil { return err } else { currentStep.Frame = stepFrame } - if boundProjections, err := buildVisibleScopeProjections(scope, currentStep.Definitions); err != nil { + if boundProjections, err := buildVisibleScopeProjections(s.query.Scope, currentStep.Definitions); err != nil { return err } else { currentStep.Projection = boundProjections.Items @@ -132,7 +137,7 @@ func (s *Translator) translateNodePatternToStep(scope *Scope, part *PatternPart, } } else { // If this isn't a traversal - if nodeFrame, err := scope.PushFrame(); err != nil { + if nodeFrame, err := s.query.Scope.PushFrame(); err != nil { return err } else { part.NodeSelect.Frame = nodeFrame @@ -142,7 +147,7 @@ func (s *Translator) translateNodePatternToStep(scope *Scope, part *PatternPart, if bindingResult.AlreadyBound { part.NodeSelect.IsDefinition = false - if boundProjections, err := buildVisibleScopeProjections(scope, nil); err != nil { + if boundProjections, err := buildVisibleScopeProjections(s.query.Scope, nil); err != nil { return err } else { part.NodeSelect.Select.Projection = boundProjections.Items @@ -150,7 +155,7 @@ func (s *Translator) translateNodePatternToStep(scope *Scope, part *PatternPart, } else { part.NodeSelect.IsDefinition = true - if boundProjections, err := buildVisibleScopeProjections(scope, []*BoundIdentifier{bindingResult.Binding}); err != nil { + if boundProjections, err := buildVisibleScopeProjections(s.query.Scope, []*BoundIdentifier{bindingResult.Binding}); err != nil { return err } else { part.NodeSelect.Select.Projection = boundProjections.Items @@ -179,12 +184,12 @@ func consumeConstraintsFrom(visible *pgsql.IdentifierSet, trackers ...*Constrain return constraint, nil } -func (s *Translator) buildNodePattern(scope *Scope, part *PatternPart) error { +func (s *Translator) buildNodePattern(part *PatternPart) error { var ( nextSelect pgsql.Select ) - if part.NodeSelect.Frame.Previous != nil { + if part.NodeSelect.Frame.Previous != nil && (s.query.CurrentPart().frame == nil || part.NodeSelect.Frame.Previous.Binding.Identifier != s.query.CurrentPart().frame.Binding.Identifier) { nextSelect.From = append(nextSelect.From, pgsql.FromClause{ Source: pgsql.TableReference{ Name: pgsql.CompoundIdentifier{part.NodeSelect.Frame.Previous.Binding.Identifier}, @@ -211,7 +216,7 @@ func (s *Translator) buildNodePattern(scope *Scope, part *PatternPart) error { }) // Prepare the next select statement - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: part.NodeSelect.Frame.Binding.Identifier, }, diff --git a/packages/go/cypher/models/pgsql/translate/pattern.go b/packages/go/cypher/models/pgsql/translate/pattern.go index 5b19038a82..c16a5f6393 100644 --- a/packages/go/cypher/models/pgsql/translate/pattern.go +++ b/packages/go/cypher/models/pgsql/translate/pattern.go @@ -18,7 +18,7 @@ package translate import ( "github.com/specterops/bloodhound/cypher/models" - cypher "github.com/specterops/bloodhound/cypher/models/cypher" + "github.com/specterops/bloodhound/cypher/models/cypher" "github.com/specterops/bloodhound/cypher/models/pgsql" ) @@ -27,19 +27,19 @@ type BindingResult struct { AlreadyBound bool } -func (s *Translator) bindPatternExpression(scope *Scope, cypherExpression cypher.Expression, dataType pgsql.DataType) (BindingResult, error) { +func (s *Translator) bindPatternExpression(cypherExpression cypher.Expression, dataType pgsql.DataType) (BindingResult, error) { if cypherBinding, hasCypherBinding, err := extractIdentifierFromCypherExpression(cypherExpression); err != nil { return BindingResult{}, err - } else if existingBinding, bound := scope.AliasedLookup(cypherBinding); bound { + } else if existingBinding, bound := s.query.Scope.AliasedLookup(cypherBinding); bound { return BindingResult{ Binding: existingBinding, AlreadyBound: true, }, nil - } else if binding, err := scope.DefineNew(dataType); err != nil { + } else if binding, err := s.query.Scope.DefineNew(dataType); err != nil { return BindingResult{}, err } else { if hasCypherBinding { - scope.Alias(cypherBinding, binding) + s.query.Scope.Alias(cypherBinding, binding) } return BindingResult{ @@ -51,7 +51,7 @@ func (s *Translator) bindPatternExpression(scope *Scope, cypherExpression cypher func (s *Translator) translatePatternPart(scope *Scope, patternPart *cypher.PatternPart) error { // We expect this to be a node select if there aren't enough pattern elements for a traversal - newPatternPart := s.pattern.NewPart() + newPatternPart := s.query.CurrentPart().pattern.NewPart() newPatternPart.IsTraversal = len(patternPart.PatternElements) > 1 newPatternPart.ShortestPath = patternPart.ShortestPathPattern newPatternPart.AllShortestPaths = patternPart.AllShortestPathsPattern @@ -73,22 +73,22 @@ func (s *Translator) translatePatternPart(scope *Scope, patternPart *cypher.Patt return nil } -func (s *Translator) buildPatternPart(scope *Scope, part *PatternPart) error { +func (s *Translator) buildPatternPart(part *PatternPart) error { if part.IsTraversal { - return s.buildPattern(scope, part) + return s.buildPattern(part) } else { - return s.buildNodePattern(scope, part) + return s.buildNodePattern(part) } } -func (s *Translator) buildPattern(scope *Scope, part *PatternPart) error { +func (s *Translator) buildPattern(part *PatternPart) error { for idx, traversalStep := range part.TraversalSteps { if traversalStep.Expansion.Set { if idx > 0 { if traversalStepQuery, err := s.buildExpansionPatternStep(part, traversalStep); err != nil { return err } else { - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: traversalStep.Expansion.Value.Frame.Binding.Identifier, }, @@ -99,7 +99,7 @@ func (s *Translator) buildPattern(scope *Scope, part *PatternPart) error { if traversalStepQuery, err := s.buildExpansionPatternRoot(part, traversalStep); err != nil { return err } else { - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: traversalStep.Expansion.Value.Frame.Binding.Identifier, }, @@ -111,7 +111,7 @@ func (s *Translator) buildPattern(scope *Scope, part *PatternPart) error { if traversalStepQuery, err := s.buildTraversalPatternStep(part, traversalStep); err != nil { return err } else { - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: traversalStep.Frame.Binding.Identifier, }, @@ -122,7 +122,7 @@ func (s *Translator) buildPattern(scope *Scope, part *PatternPart) error { if traversalStepQuery, err := s.buildTraversalPatternRoot(part, traversalStep); err != nil { return err } else { - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: traversalStep.Frame.Binding.Identifier, }, diff --git a/packages/go/cypher/models/pgsql/translate/predicate.go b/packages/go/cypher/models/pgsql/translate/predicate.go index ff72366e2c..2afef4c718 100644 --- a/packages/go/cypher/models/pgsql/translate/predicate.go +++ b/packages/go/cypher/models/pgsql/translate/predicate.go @@ -26,10 +26,10 @@ import ( func (s *Translator) translatePatternPredicate(scope *Scope) error { // Set the pattern frame - s.pattern.Frame = scope.CurrentFrame() + s.query.CurrentPart().pattern.Frame = scope.CurrentFrame() // All pattern predicates must be relationship patterns - newPatternPart := s.pattern.NewPart() + newPatternPart := s.query.CurrentPart().pattern.NewPart() newPatternPart.IsTraversal = true return nil @@ -54,18 +54,20 @@ func (s *Translator) buildOptimizedRelationshipExistPredicate(part *PatternPart, // explain analyze select * from node n0 where not exists(select 1 from edge e0 where e0.start_id = n0.id or e0.end_id = n0.id); s.treeTranslator.Push(pgsql.ExistsExpression{ - Subquery: pgsql.Query{ - Body: pgsql.Select{ - Projection: []pgsql.SelectItem{ - pgsql.NewLiteral(1, pgsql.Int), - }, - From: []pgsql.FromClause{{ - Source: pgsql.TableReference{ - Name: pgsql.CompoundIdentifier{pgsql.TableEdge}, - Binding: models.ValueOptional(traversalStep.Edge.Identifier), - }}, + Subquery: pgsql.Subquery{ + Query: pgsql.Query{ + Body: pgsql.Select{ + Projection: []pgsql.SelectItem{ + pgsql.NewLiteral(1, pgsql.Int), + }, + From: []pgsql.FromClause{{ + Source: pgsql.TableReference{ + Name: pgsql.CompoundIdentifier{pgsql.TableEdge}, + Binding: models.ValueOptional(traversalStep.Edge.Identifier), + }}, + }, + Where: whereClause, }, - Where: whereClause, }, }, }) @@ -74,14 +76,14 @@ func (s *Translator) buildOptimizedRelationshipExistPredicate(part *PatternPart, } func (s *Translator) buildPatternPredicate() error { - if numPatternParts := len(s.pattern.Parts); numPatternParts < 1 || numPatternParts > 1 { + if numPatternParts := len(s.query.CurrentPart().pattern.Parts); numPatternParts < 1 || numPatternParts > 1 { return fmt.Errorf("expected exactly one pattern part for pattern predicate but found: %d", numPatternParts) } var ( lastFrame *Frame - patternPart = s.pattern.Parts[0] + patternPart = s.query.CurrentPart().pattern.Parts[0] subQuery = pgsql.Query{ CommonTableExpressions: &pgsql.With{}, } @@ -162,8 +164,8 @@ func (s *Translator) buildPatternPredicate() error { }}, } - s.treeTranslator.Push(pgsql.Parenthetical{ - Expression: subQuery, + s.treeTranslator.Push(pgsql.Subquery{ + Query: subQuery, }) return nil diff --git a/packages/go/cypher/models/pgsql/translate/projection.go b/packages/go/cypher/models/pgsql/translate/projection.go index 5b3259dfbd..80d79923f8 100644 --- a/packages/go/cypher/models/pgsql/translate/projection.go +++ b/packages/go/cypher/models/pgsql/translate/projection.go @@ -18,7 +18,6 @@ package translate import ( "fmt" - "github.com/specterops/bloodhound/cypher/models" "github.com/specterops/bloodhound/cypher/models/pgsql" ) @@ -70,7 +69,7 @@ func buildExternalProjection(scope *Scope, projections []*Projection) (pgsql.Pro alias = projectedBinding.Alias.Value } - if builtProjection, err := buildProjection(alias, projectedBinding, scope); err != nil { + if builtProjection, err := buildProjection(alias, projectedBinding, scope, projectedBinding.LastProjection); err != nil { return nil, err } else { for _, buildProjectionItem := range builtProjection { @@ -115,7 +114,7 @@ func buildInternalProjection(scope *Scope, projectedBindings []*BoundIdentifier) projected[projectedBinding.Identifier] = struct{}{} // Build the identifier's projection - if newSelectItems, err := buildProjection(projectedBinding.Identifier, projectedBinding, scope); err != nil { + if newSelectItems, err := buildProjection(projectedBinding.Identifier, projectedBinding, scope, projectedBinding.LastProjection); err != nil { return BoundProjections{}, err } else { boundProjections.Items = append(boundProjections.Items, newSelectItems...) @@ -126,26 +125,45 @@ func buildInternalProjection(scope *Scope, projectedBindings []*BoundIdentifier) return boundProjections, nil } -func buildVisibleScopeProjections(scope *Scope, boundIdentifiers []*BoundIdentifier) (BoundProjections, error) { - if visibleBindings, err := scope.LookupBindings(scope.Visible().Slice()...); err != nil { - return BoundProjections{}, err - } else if projection, err := buildInternalProjection(scope, append(visibleBindings, boundIdentifiers...)); err != nil { +func buildVisibleScopeProjections(scope *Scope, newlyBound []*BoundIdentifier) (BoundProjections, error) { + currentFrame := scope.CurrentFrame() + + if visibleBindings, err := scope.LookupBindings(currentFrame.Known().Slice()...); err != nil { return BoundProjections{}, err } else { - for _, boundIdentifier := range boundIdentifiers { - scope.Declare(boundIdentifier.Identifier) - } + allVisibleIdentifiers := append(visibleBindings, newlyBound...) + + if projection, err := buildInternalProjection(scope, allVisibleIdentifiers); err != nil { + return BoundProjections{}, err + } else { + // Mark all new bound identifiers as visible so they do not get reconstructed again on reference + for _, boundIdentifier := range newlyBound { + currentFrame.Reveal(boundIdentifier.Identifier) + currentFrame.Export(boundIdentifier.Identifier) + } + + // Zip through all projected identifiers and update their last projecte frame + for _, boundIdentifier := range allVisibleIdentifiers { + boundIdentifier.LastProjection = currentFrame + } - return projection, nil + return projection, nil + } } } -func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope *Scope) ([]pgsql.SelectItem, error) { - referenceFrame := scope.ReferenceFrame() - +func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope *Scope, referenceFrame *Frame) ([]pgsql.SelectItem, error) { switch projected.DataType { + case pgsql.InlineProjection: + return []pgsql.SelectItem{ + &pgsql.AliasedExpression{ + Expression: pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier}, + Alias: pgsql.AsOptionalIdentifier(alias), + }, + }, nil + case pgsql.ExpansionPath: - if scope.IsVisible(projected.Identifier) { + if scope.IsMaterialized(projected.Identifier) { return []pgsql.SelectItem{ &pgsql.AliasedExpression{ Expression: pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier}, @@ -218,7 +236,7 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * }, nil case pgsql.ExpansionRootNode, pgsql.ExpansionTerminalNode: - if scope.IsVisible(projected.Identifier) { + if scope.IsMaterialized(projected.Identifier) { return []pgsql.SelectItem{ &pgsql.AliasedExpression{ Expression: pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier}, @@ -247,7 +265,7 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * }, nil case pgsql.NodeComposite: - if scope.IsVisible(projected.Identifier) { + if scope.IsMaterialized(projected.Identifier) { return []pgsql.SelectItem{ &pgsql.AliasedExpression{ Expression: pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier}, @@ -321,7 +339,7 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * } case pgsql.EdgeComposite: - if scope.IsVisible(projected.Identifier) { + if scope.IsMaterialized(projected.Identifier) { return []pgsql.SelectItem{ &pgsql.AliasedExpression{ Expression: pgsql.CompoundIdentifier{referenceFrame.Binding.Identifier, projected.Identifier}, diff --git a/packages/go/cypher/models/pgsql/translate/query.go b/packages/go/cypher/models/pgsql/translate/query.go new file mode 100644 index 0000000000..cf1f950cd1 --- /dev/null +++ b/packages/go/cypher/models/pgsql/translate/query.go @@ -0,0 +1,182 @@ +package translate + +import ( + "fmt" + "github.com/specterops/bloodhound/cypher/models/cypher" + "github.com/specterops/bloodhound/cypher/models/pgsql" +) + +func (s *Translator) buildMultiPartSinglePartQuery(singlePartQuery *cypher.SinglePartQuery, cteChain []pgsql.CommonTableExpression) error { + // Prepend the CTE chain to the model's + currentPart := s.query.CurrentPart() + currentPart.Model.CommonTableExpressions.Expressions = append(cteChain, currentPart.Model.CommonTableExpressions.Expressions...) + + return nil +} + +func (s *Translator) buildSinglePartQuery(singlePartQuery *cypher.SinglePartQuery) error { + if s.query.CurrentPart().HasMutations() { + if err := s.translateUpdates(s.query.Scope); err != nil { + s.SetError(err) + } + + if err := s.buildUpdates(s.query.Scope); err != nil { + s.SetError(err) + } + } + + if s.query.CurrentPart().HasDeletions() { + if err := s.buildDeletions(s.query.Scope); err != nil { + s.SetError(err) + } + } + + // If there was no return specified end the CTE chain with a bare select + if singlePartQuery.Return == nil { + if literalReturn, err := pgsql.AsLiteral(1); err != nil { + s.SetError(err) + } else { + s.query.CurrentPart().Model.Body = pgsql.Select{ + Projection: []pgsql.SelectItem{literalReturn}, + } + } + } else if err := s.buildTailProjection(); err != nil { + s.SetError(err) + } + + return nil +} + +func (s *Translator) buildMultiPartQuery(singlePartQuery *cypher.SinglePartQuery) error { + var multipartCTEChain []pgsql.CommonTableExpression + + // In order to author the multipart query part we first have to wrap it in a + for _, part := range s.query.Parts[:len(s.query.Parts)-1] { + // If the part has an empty inner CTE, make sure to remove it otherwise the keyword will still render + if len(part.Model.CommonTableExpressions.Expressions) == 0 { + part.Model.CommonTableExpressions = nil + } + + // Autor the part as a nested CTE + nextCTE := pgsql.CommonTableExpression{ + Query: *part.Model, + } + + if part.frame != nil { + nextCTE.Alias = pgsql.TableAlias{ + Name: part.frame.Binding.Identifier, + } + } + + if inlineSelect, err := s.buildInlineProjection(part); err != nil { + return err + } else { + nextCTE.Query.Body = inlineSelect + } + + multipartCTEChain = append(multipartCTEChain, nextCTE) + } + + if err := s.buildMultiPartSinglePartQuery(singlePartQuery, multipartCTEChain); err != nil { + return err + } + + s.translation.Statement = *s.query.CurrentPart().Model + return nil +} + +func (s *Translator) translateWith() error { + currentPart := s.query.CurrentPart() + + if !currentPart.HasProjections() { + currentPart.frame.Exported.Clear() + } else { + var ( + projectedItems = pgsql.NewIdentifierSet() + aggregatedItems = pgsql.NewIdentifierSet() + ) + + // If an aggregation function is being used this invokes an implicit group by of non-function projections + for _, projectionItem := range currentPart.projections.Items { + switch typedSelectItem := projectionItem.SelectItem.(type) { + case pgsql.FunctionCall: + if pgsql.IsAggregateFunction(typedSelectItem.Function) { + aggregatedItems.Add(typedSelectItem.Function) + } + } + } + + for idx, projectionItem := range currentPart.projections.Items { + switch typedSelectItem := projectionItem.SelectItem.(type) { + case *pgsql.BinaryExpression: + return fmt.Errorf("unhandled case for with statement") + + case pgsql.CompoundIdentifier: + return fmt.Errorf("unhandled case for with statement") + + case pgsql.Identifier: + if !aggregatedItems.IsEmpty() && !aggregatedItems.Contains(typedSelectItem) { + currentPart.projections.GroupBy = append(currentPart.projections.GroupBy, typedSelectItem) + } + + if binding, isBound := s.query.Scope.Lookup(typedSelectItem); !isBound { + return fmt.Errorf("unable to lookup identifer %s for with statement", typedSelectItem) + } else { + // Track this projected item for scope pruning + projectedItems.Add(binding.Identifier) + + // Create a new projection that maps the identifier + currentPart.projections.Items[idx] = &Projection{ + SelectItem: pgsql.CompoundIdentifier{ + binding.LastProjection.Binding.Identifier, typedSelectItem, + }, + Alias: pgsql.AsOptionalIdentifier(binding.Identifier), + } + + // Assign the frame to the binding's last projection backref + binding.LastProjection = currentPart.frame + + // Reveal and export the identifier in the current multipart query part's frame + currentPart.frame.Reveal(binding.Identifier) + currentPart.frame.Export(binding.Identifier) + } + + default: + // If this is not an identifier then check if the alias is specified. If the alias is specified, this + // is a pure export (left-hand side is some other expression) and a new bound identifier is being + // introduced. + if projectionItem.Alias.Set { + if binding, isBound := s.query.Scope.AliasedLookup(projectionItem.Alias.Value); !isBound { + return fmt.Errorf("unable to lookup alias %s for with statement", projectionItem.Alias.Value) + } else { + // Track this projected item for scope pruning + projectedItems.Add(binding.Identifier) + + // Assign the frame to the binding's last projection backref + binding.LastProjection = currentPart.frame + + // Reveal and export the identifier in the current multipart query part's frame + currentPart.frame.Reveal(binding.Identifier) + currentPart.frame.Export(binding.Identifier) + + // Rewrite this projection's alias to use the internal binding + projectionItem.Alias.Value = binding.Identifier + } + } + } + + // Prune scope to only what's being exported by the with statement + currentPart.frame.Visible = projectedItems.Copy() + currentPart.frame.Exported = projectedItems.Copy() + } + } + + return nil +} + +func (s *Translator) translateMultiPartQueryPart(scope *Scope, part *cypher.MultiPartQueryPart) error { + queryPart := s.query.CurrentPart() + + // Unwind nested frames + return scope.UnwindToFrame(queryPart.frame) +} diff --git a/packages/go/cypher/models/pgsql/translate/relationship.go b/packages/go/cypher/models/pgsql/translate/relationship.go index a0586207c6..8e6ffbc368 100644 --- a/packages/go/cypher/models/pgsql/translate/relationship.go +++ b/packages/go/cypher/models/pgsql/translate/relationship.go @@ -24,18 +24,23 @@ import ( "github.com/specterops/bloodhound/cypher/models/pgsql" ) -func (s *Translator) translateRelationshipPattern(scope *Scope, relationshipPattern *cypher.RelationshipPattern, part *PatternPart) error { - if bindingResult, err := s.bindPatternExpression(scope, relationshipPattern, pgsql.EdgeComposite); err != nil { +func (s *Translator) translateRelationshipPattern(relationshipPattern *cypher.RelationshipPattern) error { + var ( + queryPart = s.query.CurrentPart() + patternPart = queryPart.pattern.CurrentPart() + ) + + if bindingResult, err := s.bindPatternExpression(relationshipPattern, pgsql.EdgeComposite); err != nil { return err } else { - if err := s.translateRelationshipPatternToStep(scope, bindingResult, part, relationshipPattern); err != nil { + if err := s.translateRelationshipPatternToStep(bindingResult, patternPart, relationshipPattern); err != nil { return err } - if len(s.properties) > 0 { + if len(queryPart.properties) > 0 { var propertyConstraints pgsql.Expression - for key, value := range s.properties { + for key, value := range queryPart.properties { propertyConstraints = pgsql.OptionalAnd(propertyConstraints, pgsql.NewBinaryExpression( pgsql.NewPropertyLookup(pgsql.CompoundIdentifier{bindingResult.Binding.Identifier, pgsql.ColumnProperties}, pgsql.NewLiteral(key, pgsql.Text)), pgsql.OperatorEquals, @@ -43,7 +48,7 @@ func (s *Translator) translateRelationshipPattern(scope *Scope, relationshipPatt )) } - if err := part.Constraints.Constrain(pgsql.AsIdentifierSet(bindingResult.Binding.Identifier), propertyConstraints); err != nil { + if err := patternPart.Constraints.Constrain(pgsql.AsIdentifierSet(bindingResult.Binding.Identifier), propertyConstraints); err != nil { return err } } @@ -64,7 +69,7 @@ func (s *Translator) translateRelationshipPattern(scope *Scope, relationshipPatt ) ) - if err := part.Constraints.Constrain(dependencies, expression); err != nil { + if err := patternPart.Constraints.Constrain(dependencies, expression); err != nil { return err } } @@ -74,7 +79,7 @@ func (s *Translator) translateRelationshipPattern(scope *Scope, relationshipPatt return nil } -func (s *Translator) translateRelationshipPatternToStep(scope *Scope, bindingResult BindingResult, part *PatternPart, relationshipPattern *cypher.RelationshipPattern) error { +func (s *Translator) translateRelationshipPatternToStep(bindingResult BindingResult, part *PatternPart, relationshipPattern *cypher.RelationshipPattern) error { var ( expansion models.Optional[Expansion] numSteps = len(part.TraversalSteps) @@ -107,7 +112,7 @@ func (s *Translator) translateRelationshipPatternToStep(scope *Scope, bindingRes // Set the edge type to an expansion of edges bindingResult.Binding.DataType = pgsql.ExpansionEdge - if expansionScopeBinding, err := scope.DefineNew(pgsql.ExpansionPattern); err != nil { + if expansionScopeBinding, err := s.query.Scope.DefineNew(pgsql.ExpansionPattern); err != nil { return err } else { // Link the edge to the expansion @@ -126,7 +131,7 @@ func (s *Translator) translateRelationshipPatternToStep(scope *Scope, bindingRes MaxDepth: models.PointerOptional(relationshipPattern.Range.EndIndex), }) - if expansionPathBinding, err := scope.DefineNew(pgsql.ExpansionPath); err != nil { + if expansionPathBinding, err := s.query.Scope.DefineNew(pgsql.ExpansionPath); err != nil { return err } else { // Link the path array to the expansion that declares it diff --git a/packages/go/cypher/models/pgsql/translate/renamer.go b/packages/go/cypher/models/pgsql/translate/renamer.go index bbd3bb48fe..9f078a4853 100644 --- a/packages/go/cypher/models/pgsql/translate/renamer.go +++ b/packages/go/cypher/models/pgsql/translate/renamer.go @@ -64,10 +64,21 @@ type IdentifierRewriter struct { scopeIdentifier pgsql.Identifier targets *pgsql.IdentifierSet - stack []pgsql.SyntaxNode + skipDepth int } func (s *IdentifierRewriter) enter(node pgsql.SyntaxNode) error { + // Quick check to compensate for unwanted rewriting of sub-query expressions. Since these + // can be nested we track depth instead a boolean for skipping AST elements. + switch node.(type) { + case pgsql.Subquery: + s.skipDepth += 1 + } + + if s.skipDepth > 0 { + return nil + } + switch typedExpression := node.(type) { case pgsql.Projection: for idx, projection := range typedExpression { @@ -193,7 +204,6 @@ func (s *IdentifierRewriter) enter(node pgsql.SyntaxNode) error { } } - s.stack = append(s.stack, node) return nil } @@ -201,12 +211,13 @@ func (s *IdentifierRewriter) Enter(node pgsql.SyntaxNode) { if err := s.enter(node); err != nil { s.SetError(err) } - - s.stack = append(s.stack, node) } -func (s *IdentifierRewriter) exit(_ pgsql.SyntaxNode) error { - s.stack = s.stack[:len(s.stack)-1] +func (s *IdentifierRewriter) exit(node pgsql.SyntaxNode) error { + switch node.(type) { + case pgsql.Subquery: + s.skipDepth -= 1 + } return nil } @@ -221,6 +232,7 @@ func NewIdentifierRewriter(scopeIdentifier pgsql.Identifier, targets *pgsql.Iden HierarchicalVisitor: walk.NewComposableHierarchicalVisitor[pgsql.SyntaxNode](), scopeIdentifier: scopeIdentifier, targets: targets, + skipDepth: 0, } } diff --git a/packages/go/cypher/models/pgsql/translate/tracking.go b/packages/go/cypher/models/pgsql/translate/tracking.go index 9c381210fa..5d09ceeff4 100644 --- a/packages/go/cypher/models/pgsql/translate/tracking.go +++ b/packages/go/cypher/models/pgsql/translate/tracking.go @@ -50,6 +50,8 @@ func (s IdentifierGenerator) NewIdentifier(dataType pgsql.DataType) (pgsql.Ident return pgsql.Identifier("s" + nextIDStr), nil case pgsql.ParameterIdentifier: return pgsql.Identifier("pi" + nextIDStr), nil + case pgsql.InlineProjection: + return pgsql.Identifier("i" + nextIDStr), nil default: return "", fmt.Errorf("identifier with data type %s does not have a prefix case", dataType) } @@ -238,9 +240,23 @@ func (s *ConstraintTracker) Constrain(dependencies *pgsql.IdentifierSet, constra // Frame represents a snapshot of all identifiers defined and visible in a given scope type Frame struct { + id int Previous *Frame Binding *BoundIdentifier Visible *pgsql.IdentifierSet + Exported *pgsql.IdentifierSet +} + +func (s *Frame) Known() *pgsql.IdentifierSet { + return s.Visible.Copy().MergeSet(s.Exported) +} + +func (s *Frame) Reveal(identifier pgsql.Identifier) { + s.Visible.Add(identifier) +} + +func (s *Frame) Export(identifier pgsql.Identifier) { + s.Exported.Add(identifier) } // Scope contains all identifier definitions and their temporal resolutions in a []*Frame field. @@ -253,6 +269,7 @@ type Frame struct { // all visible projections. This is required when disambiguating references that otherwise belong to // a frame. type Scope struct { + nextFrameID int stack []*Frame generator IdentifierGenerator aliases map[pgsql.Identifier]pgsql.Identifier @@ -261,6 +278,7 @@ type Scope struct { func NewScope() *Scope { return &Scope{ + nextFrameID: 0, generator: NewIdentifierGenerator(), aliases: map[pgsql.Identifier]pgsql.Identifier{}, definitions: map[pgsql.Identifier]*BoundIdentifier{}, @@ -291,15 +309,40 @@ func (s *Scope) ReferenceFrame() *Frame { return s.CurrentFrame() } -func (s *Scope) PopFrame() *Frame { - frame := s.stack[len(s.stack)-1] +func (s *Scope) PopFrame() error { + if len(s.stack) <= 0 { + return fmt.Errorf("no frame to pop") + } + s.stack = s.stack[:len(s.stack)-1] + return nil +} + +func (s *Scope) UnwindToFrame(frame *Frame) error { + found := false - return frame + for idx := len(s.stack) - 1; idx >= 0; idx-- { + if found = s.stack[idx].id == frame.id; found { + s.stack = s.stack[:idx+1] + break + } + } + + if !found { + return fmt.Errorf("unable to pop frame with ID %d", frame.id) + } + + return nil } func (s *Scope) PushFrame() (*Frame, error) { - newFrame := &Frame{} + newFrame := &Frame{ + id: s.nextFrameID, + Visible: pgsql.NewIdentifierSet(), + Exported: pgsql.NewIdentifierSet(), + } + + s.nextFrameID += 1 if nextScopeBinding, err := s.DefineNew(pgsql.Scope); err != nil { return nil, err @@ -312,7 +355,8 @@ func (s *Scope) PushFrame() (*Frame, error) { newFrame.Previous = s.stack[len(s.stack)-1] } - newFrame.Visible = currentFrame.Visible.Copy() + newFrame.Visible = currentFrame.Exported.Copy() + newFrame.Exported = currentFrame.Exported.Copy() } else { newFrame.Visible = pgsql.NewIdentifierSet() } @@ -329,8 +373,12 @@ func (s *Scope) CurrentFrameBinding() *BoundIdentifier { return nil } -func (s *Scope) IsVisible(identifier pgsql.Identifier) bool { - return s.CurrentFrame().Visible.Contains(identifier) +func (s *Scope) IsMaterialized(identifier pgsql.Identifier) bool { + if binding, isBound := s.definitions[identifier]; isBound { + return binding.LastProjection != nil + } + + return false } func (s *Scope) Visible() *pgsql.IdentifierSet { @@ -406,11 +454,12 @@ func (s *Scope) Define(identifier pgsql.Identifier, dataType pgsql.DataType) *Bo // will eagerly bind anonymous identifiers for traversal steps and rebind existing identifiers and their // aliases to prevent naming collisions. type BoundIdentifier struct { - Identifier pgsql.Identifier - Alias models.Optional[pgsql.Identifier] - Parameter models.Optional[*pgsql.Parameter] - Dependencies []*BoundIdentifier - DataType pgsql.DataType + Identifier pgsql.Identifier + Alias models.Optional[pgsql.Identifier] + Parameter models.Optional[*pgsql.Parameter] + LastProjection *Frame + Dependencies []*BoundIdentifier + DataType pgsql.DataType } func (s *BoundIdentifier) Aliased() pgsql.Identifier { diff --git a/packages/go/cypher/models/pgsql/translate/tracking_test.go b/packages/go/cypher/models/pgsql/translate/tracking_test.go new file mode 100644 index 0000000000..f15db2ca34 --- /dev/null +++ b/packages/go/cypher/models/pgsql/translate/tracking_test.go @@ -0,0 +1,28 @@ +package translate + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestScope(t *testing.T) { + var ( + scope = NewScope() + ) + + grandparent, err := scope.PushFrame() + require.Nil(t, err) + + parent, err := scope.PushFrame() + require.Nil(t, err) + + child, err := scope.PushFrame() + require.Nil(t, err) + + require.Equal(t, 0, grandparent.id) + require.Equal(t, 1, parent.id) + require.Equal(t, 2, child.id) + + require.Nil(t, scope.UnwindToFrame(parent)) + require.Equal(t, parent.id, scope.CurrentFrame().id) +} diff --git a/packages/go/cypher/models/pgsql/translate/translation.go b/packages/go/cypher/models/pgsql/translate/translation.go index c5fbac294b..698efc8622 100644 --- a/packages/go/cypher/models/pgsql/translate/translation.go +++ b/packages/go/cypher/models/pgsql/translate/translation.go @@ -42,7 +42,7 @@ func (s *Translator) translateRemoveItem(removeItem *cypher.RemoveItem) error { } else if binding, resolved := s.query.Scope.LookupString(variable.Symbol); !resolved { return fmt.Errorf("unable to find identifier %s", variable.Symbol) } else { - return s.mutations.AddKindRemoval(s.query.Scope, binding.Identifier, removeItem.KindMatcher.Kinds) + return s.query.CurrentPart().mutations.AddKindRemoval(s.query.Scope, binding.Identifier, removeItem.KindMatcher.Kinds) } } @@ -52,7 +52,7 @@ func (s *Translator) translateRemoveItem(removeItem *cypher.RemoveItem) error { } else if propertyLookup, err := decomposePropertyLookup(propertyLookupExpression); err != nil { return err } else { - return s.mutations.AddPropertyRemoval(s.query.Scope, propertyLookup) + return s.query.CurrentPart().mutations.AddPropertyRemoval(s.query.Scope, propertyLookup) } } @@ -142,7 +142,7 @@ func (s *Translator) translateSetItem(setItem *cypher.SetItem) error { } else if leftPropertyLookup, err := decomposePropertyLookup(leftOperand); err != nil { return err } else { - return s.mutations.AddPropertyAssignment(s.query.Scope, leftPropertyLookup, operator, rightOperand) + return s.query.CurrentPart().mutations.AddPropertyAssignment(s.query.Scope, leftPropertyLookup, operator, rightOperand) } case pgsql.OperatorKindAssignment: @@ -155,7 +155,7 @@ func (s *Translator) translateSetItem(setItem *cypher.SetItem) error { } else if kindList, isKindListLiteral := rightOperand.(pgsql.KindListLiteral); !isKindListLiteral { return fmt.Errorf("expected an identifier for kind list right operand but got: %T", rightOperand) } else { - return s.mutations.AddKindAssignment(s.query.Scope, targetIdentifier, kindList.Values) + return s.query.CurrentPart().mutations.AddKindAssignment(s.query.Scope, targetIdentifier, kindList.Values) } default: @@ -321,6 +321,22 @@ func (s *Translator) translateKindMatcher(kindMatcher *cypher.KindMatcher) error return nil } +func unwrapParenthetical(parenthetical pgsql.Expression) pgsql.Expression { + next := parenthetical + + for next != nil { + switch typedNext := next.(type) { + case pgsql.Parenthetical: + next = typedNext.Expression + + default: + return next + } + } + + return parenthetical +} + func (s *Translator) translateProjectionItem(scope *Scope, projectionItem *cypher.ProjectionItem) error { if alias, hasAlias, err := extractIdentifierFromCypherExpression(projectionItem); err != nil { return err @@ -329,38 +345,58 @@ func (s *Translator) translateProjectionItem(scope *Scope, projectionItem *cyphe } else if selectItem, isProjection := nextExpression.(pgsql.SelectItem); !isProjection { s.SetErrorf("invalid type for select item: %T", nextExpression) } else { - if hasAlias { - s.projections.CurrentProjection().SetAlias(alias) - } - - switch typedSelectItem := selectItem.(type) { + switch typedSelectItem := unwrapParenthetical(selectItem).(type) { case pgsql.Identifier: + // Identifier lookups will require a projection + s.query.CurrentPart().projections.Frame = s.query.Scope.CurrentFrame() + // If this is an identifier then assume the identifier as the projection alias since the translator // rewrites all identifiers if !hasAlias { if boundSelectItem, bound := scope.Lookup(typedSelectItem); !bound { return fmt.Errorf("invalid identifier: %s", typedSelectItem) } else { - s.projections.CurrentProjection().SetAlias(boundSelectItem.Aliased()) + s.query.CurrentPart().CurrentProjection().SetAlias(boundSelectItem.Aliased()) } } case *pgsql.BinaryExpression: + // Binary expressions are used when properties are returned from a result projection + // e.g. match (n) return n.prop if propertyLookup, isPropertyLookup := asPropertyLookup(typedSelectItem); isPropertyLookup { + // Property lookups will require a scope reference + s.query.CurrentPart().projections.Frame = s.query.Scope.CurrentFrame() + // Ensure that projections maintain the raw JSONB type of the field propertyLookup.Operator = pgsql.OperatorJSONField } + + default: + if hasAlias { + if _, isBound := s.query.Scope.AliasedLookup(alias); !isBound { + if newBinding, err := s.query.Scope.DefineNew(pgsql.InlineProjection); err != nil { + return err + } else { + // This binding is its own alias + s.query.Scope.Alias(alias, newBinding) + } + } + } + } + + if hasAlias { + s.query.CurrentPart().CurrentProjection().SetAlias(alias) } - s.projections.CurrentProjection().SelectItem = selectItem + s.query.CurrentPart().CurrentProjection().SelectItem = selectItem } return nil } -func (s *Translator) translateProjection(projection *cypher.Projection) error { - s.projections = NewProjectionClause() - s.projections.Distinct = projection.Distinct +func (s *Translator) prepareProjection(projection *cypher.Projection) error { + currentPart := s.query.CurrentPart() + currentPart.PrepareProjections(projection.Distinct) if projection.Skip != nil { if cypherLiteral, isLiteral := projection.Skip.Value.(*cypher.Literal); !isLiteral { @@ -368,7 +404,7 @@ func (s *Translator) translateProjection(projection *cypher.Projection) error { } else if pgLiteral, err := pgsql.AsLiteral(cypherLiteral.Value); err != nil { return err } else { - s.query.Skip = models.ValueOptional[pgsql.Expression](pgLiteral) + currentPart.Skip = models.ValueOptional[pgsql.Expression](pgLiteral) } } @@ -378,7 +414,7 @@ func (s *Translator) translateProjection(projection *cypher.Projection) error { } else if pgLiteral, err := pgsql.AsLiteral(cypherLiteral.Value); err != nil { return err } else { - s.query.Limit = models.ValueOptional[pgsql.Expression](pgLiteral) + currentPart.Limit = models.ValueOptional[pgsql.Expression](pgLiteral) } } diff --git a/packages/go/cypher/models/pgsql/translate/translator.go b/packages/go/cypher/models/pgsql/translate/translator.go index cb9d77baba..e7fe7e7b9a 100644 --- a/packages/go/cypher/models/pgsql/translate/translator.go +++ b/packages/go/cypher/models/pgsql/translate/translator.go @@ -18,9 +18,6 @@ package translate import ( "context" - "fmt" - "strings" - "github.com/specterops/bloodhound/cypher/models" "github.com/specterops/bloodhound/cypher/models/cypher" "github.com/specterops/bloodhound/cypher/models/pgsql" @@ -28,57 +25,13 @@ import ( "github.com/specterops/bloodhound/dawgs/graph" ) -type State int - -const ( - StateTranslatingStart State = iota - StateTranslatingPatternPart - StateTranslatingMatch - StateTranslatingCreate - StateTranslatingWhere - StateTranslatingProjection - StateTranslatingOrderBy - StateTranslatingUpdateClause - StateTranslatingPatternPredicate - StateTranslatingNestedExpression -) - -func (s State) String() string { - switch s { - case StateTranslatingStart: - return "start" - case StateTranslatingPatternPart: - return "pattern part" - case StateTranslatingMatch: - return "match clause" - case StateTranslatingWhere: - return "where clause" - case StateTranslatingProjection: - return "projection" - case StateTranslatingOrderBy: - return "order by" - case StateTranslatingPatternPredicate: - return "pattern predicate" - case StateTranslatingNestedExpression: - return "nested expression" - default: - return "" - } -} - type Translator struct { walk.HierarchicalVisitor[cypher.SyntaxNode] ctx context.Context kindMapper pgsql.KindMapper translation Result - state []State treeTranslator *ExpressionTreeTranslator - properties map[string]pgsql.Expression - pattern *Pattern - match *Match - projections *ProjectionClause - mutations *Mutations query *Query } @@ -95,24 +48,9 @@ func NewTranslator(ctx context.Context, kindMapper pgsql.KindMapper, parameters ctx: ctx, kindMapper: kindMapper, treeTranslator: NewExpressionTreeTranslator(), - properties: map[string]pgsql.Expression{}, - pattern: &Pattern{}, - } -} - -func (s *Translator) currentState() State { - return s.state[len(s.state)-1] -} - -func (s *Translator) pushState(state State) { - s.state = append(s.state, state) -} - -func (s *Translator) exitState(expectedState State) { - if currentState := s.currentState(); currentState != expectedState { - s.SetErrorf("expected state %s but found %s", expectedState, currentState) - } else { - s.state = s.state[:len(s.state)-1] + query: &Query{ + Scope: NewScope(), + }, } } @@ -121,43 +59,35 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { case *cypher.RegularQuery, *cypher.SingleQuery, *cypher.PatternElement, *cypher.Return, *cypher.Comparison, *cypher.Skip, *cypher.Limit, cypher.Operator, *cypher.ArithmeticExpression, *cypher.NodePattern, *cypher.RelationshipPattern, *cypher.Remove, *cypher.Set, - *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression, *cypher.PropertyLookup: + *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression, *cypher.PropertyLookup, + *cypher.Negation, *cypher.Create, *cypher.Where, *cypher.ListLiteral, + *cypher.FunctionInvocation, *cypher.Order, *cypher.RemoveItem, *cypher.SetItem, + *cypher.MapItem: // No operation for these syntax nodes - case *cypher.Negation: - s.pushState(StateTranslatingNestedExpression) - - case *cypher.SinglePartQuery: - s.query = &Query{ - Scope: NewScope(), - Model: &pgsql.Query{ - CommonTableExpressions: &pgsql.With{}, - }, + case *cypher.MultiPartQuery: + case *cypher.MultiPartQueryPart: + if err := s.query.PreparePart(true); err != nil { + s.SetError(err) } - s.mutations = NewMutations() + case *cypher.With: + s.SetError(nil) - case *cypher.Create: - s.pushState(StateTranslatingCreate) + case *cypher.SinglePartQuery: + if err := s.query.PreparePart(false); err != nil { + s.SetError(err) + } case *cypher.Match: - s.pushState(StateTranslatingMatch) - // Start with a fresh match and where clause. Instantiation of the where clause here is necessary since // cypher will store identifier constraints in the query pattern which precedes the query where clause. - s.pattern = &Pattern{} - s.match = &Match{ + s.query.CurrentPart().pattern = &Pattern{} + s.query.CurrentPart().match = &Match{ Scope: s.query.Scope, - Pattern: s.pattern, + Pattern: s.query.CurrentPart().pattern, } - case *cypher.Where: - // Track that we're in a where clause first - s.pushState(StateTranslatingWhere) - - // If there's a where AST node present in the cypher model we likely have an expression to translate - s.pushState(StateTranslatingNestedExpression) - case graph.Kinds: s.treeTranslator.Push(pgsql.KindListLiteral{ Values: typedExpression, @@ -199,30 +129,17 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { } } - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - s.treeTranslator.Push(binding.Parameter.Value) - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } + s.treeTranslator.Push(binding.Parameter.Value) case *cypher.Variable: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if binding, resolved := s.query.Scope.LookupString(typedExpression.Symbol); !resolved { - s.SetErrorf("unable to find identifier %s", typedExpression.Symbol) - } else { - s.treeTranslator.Push(binding.Identifier) - } + identifier := pgsql.Identifier(typedExpression.Symbol) - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) + if binding, resolved := s.query.Scope.AliasedLookup(identifier); !resolved { + s.SetErrorf("unable to resolve or otherwise lookup identifer %s", identifier) + } else { + s.treeTranslator.Push(binding.Identifier) } - case *cypher.ListLiteral: - s.pushState(StateTranslatingNestedExpression) - case *cypher.Literal: literalValue := typedExpression.Value @@ -235,56 +152,33 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { s.SetError(err) } else { newLiteral.Null = typedExpression.Null - - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - s.treeTranslator.Push(newLiteral) - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } + s.treeTranslator.Push(newLiteral) } case *cypher.Parenthetical: - s.pushState(StateTranslatingNestedExpression) s.treeTranslator.PushParenthetical() - case *cypher.FunctionInvocation: - s.pushState(StateTranslatingNestedExpression) - - case *cypher.Order: - s.pushState(StateTranslatingOrderBy) - case *cypher.SortItem: - s.pushState(StateTranslatingNestedExpression) - - s.query.OrderBy = append(s.query.OrderBy, pgsql.OrderBy{ + s.query.CurrentPart().OrderBy = append(s.query.CurrentPart().OrderBy, pgsql.OrderBy{ Ascending: typedExpression.Ascending, }) case *cypher.Projection: - s.pushState(StateTranslatingProjection) - - if err := s.translateProjection(typedExpression); err != nil { + if err := s.prepareProjection(typedExpression); err != nil { s.SetError(err) } case *cypher.ProjectionItem: - s.pushState(StateTranslatingNestedExpression) - s.projections.PushProjection() + s.query.CurrentPart().PrepareProjection() case *cypher.PatternPredicate: - s.pushState(StateTranslatingPatternPredicate) - - s.pattern = &Pattern{} + s.query.CurrentPart().pattern = &Pattern{} if err := s.translatePatternPredicate(s.query.Scope); err != nil { s.SetError(err) } case *cypher.PatternPart: - s.pushState(StateTranslatingPatternPart) - if err := s.translatePatternPart(s.query.Scope, typedExpression); err != nil { s.SetError(err) } @@ -306,22 +200,13 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { } case *cypher.UpdatingClause: - s.pushState(StateTranslatingUpdateClause) - - case *cypher.RemoveItem: - s.pushState(StateTranslatingNestedExpression) - - case *cypher.Delete: - s.pushState(StateTranslatingNestedExpression) - - case *cypher.SetItem: - s.pushState(StateTranslatingNestedExpression) + s.query.CurrentPart().PrepareMutations() case *cypher.Properties: - clear(s.properties) + s.query.CurrentPart().PrepareProperties() - case *cypher.MapItem: - s.pushState(StateTranslatingNestedExpression) + case *cypher.Delete: + s.query.CurrentPart().PrepareMutations() default: s.SetErrorf("unable to translate cypher type: %T", expression) @@ -331,63 +216,46 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { func (s *Translator) Exit(expression cypher.SyntaxNode) { switch typedExpression := expression.(type) { case *cypher.NodePattern: - if err := s.translateNodePattern(s.query.Scope, typedExpression, s.pattern.CurrentPart()); err != nil { + if err := s.translateNodePattern(typedExpression); err != nil { s.SetError(err) } case *cypher.RelationshipPattern: - if err := s.translateRelationshipPattern(s.query.Scope, typedExpression, s.pattern.CurrentPart()); err != nil { + if err := s.translateRelationshipPattern(typedExpression); err != nil { s.SetError(err) } case *cypher.MapItem: - s.exitState(StateTranslatingNestedExpression) - if value, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) } else { - s.properties[typedExpression.Key] = value + s.query.CurrentPart().properties[typedExpression.Key] = value } case *cypher.PatternPredicate: - s.exitState(StateTranslatingPatternPredicate) - // Retire the predicate scope frames and build the predicate - for range s.pattern.CurrentPart().TraversalSteps { - s.query.Scope.PopFrame() - } - - if err := s.buildPatternPredicate(); err != nil { + if err := s.query.Scope.UnwindToFrame(s.query.CurrentPart().pattern.Frame); err != nil { + s.SetError(err) + } else if err := s.buildPatternPredicate(); err != nil { s.SetError(err) } case *cypher.RemoveItem: - s.exitState(StateTranslatingNestedExpression) - if err := s.translateRemoveItem(typedExpression); err != nil { s.SetError(err) } case *cypher.Delete: - s.exitState(StateTranslatingNestedExpression) - if err := s.translateDelete(s.query.Scope, typedExpression); err != nil { s.SetError(err) } case *cypher.SetItem: - s.exitState(StateTranslatingNestedExpression) - if err := s.translateSetItem(typedExpression); err != nil { s.SetError(err) } - case *cypher.UpdatingClause: - s.exitState(StateTranslatingUpdateClause) - case *cypher.ListLiteral: - s.exitState(StateTranslatingNestedExpression) - var ( numExpressions = len(typedExpression.Expressions()) literal = pgsql.ArrayLiteral{ @@ -420,12 +288,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.treeTranslator.Push(literal) } - case *cypher.Order: - s.exitState(StateTranslatingOrderBy) - case *cypher.SortItem: - s.exitState(StateTranslatingNestedExpression) - // Rewrite the order by constraints if lookupExpression, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -437,25 +300,17 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { propertyLookup.Operator = pgsql.OperatorJSONField } - s.query.CurrentOrderBy().Expression = lookupExpression + s.query.CurrentPart().CurrentOrderBy().Expression = lookupExpression } case *cypher.KindMatcher: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if matcher, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(matcher) - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) + if matcher, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(matcher) } case *cypher.Parenthetical: - s.exitState(StateTranslatingNestedExpression) - // Pull the sub-expression we wrap if wrappedExpression, err := s.treeTranslator.Pop(); err != nil { s.SetError(err) @@ -463,204 +318,11 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.SetError(err) } else { parenthetical.Expression = wrappedExpression - - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - s.treeTranslator.Push(*parenthetical) - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } + s.treeTranslator.Push(*parenthetical) } case *cypher.FunctionInvocation: - s.exitState(StateTranslatingNestedExpression) - - formattedName := strings.ToLower(typedExpression.Name) - - switch formattedName { - case cypher.IdentityFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if referenceArgument, err := PopFromBuilderAs[pgsql.Identifier](s.treeTranslator); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.CompoundIdentifier{referenceArgument, pgsql.ColumnID}) - } - - case cypher.LocalTimeFunction: - if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.TimeWithoutTimeZone); err != nil { - s.SetError(err) - } - - case cypher.LocalDateTimeFunction: - if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.TimestampWithoutTimeZone); err != nil { - s.SetError(err) - } - - case cypher.DateFunction: - if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.Date); err != nil { - s.SetError(err) - } - - case cypher.DateTimeFunction: - if err := s.translateDateTimeFunctionCall(typedExpression, pgsql.TimestampWithTimeZone); err != nil { - s.SetError(err) - } - - case cypher.EdgeTypeFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else if identifier, isIdentifier := argument.(pgsql.Identifier); !isIdentifier { - s.SetErrorf("expected an identifier for the cypher function: %s but received %T", typedExpression.Name, argument) - } else { - s.treeTranslator.Push(pgsql.CompoundIdentifier{identifier, pgsql.ColumnKindID}) - } - - case cypher.NodeLabelsFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else if identifier, isIdentifier := argument.(pgsql.Identifier); !isIdentifier { - s.SetErrorf("expected an identifier for the cypher function: %s but received %T", typedExpression.Name, argument) - } else { - s.treeTranslator.Push(pgsql.CompoundIdentifier{identifier, pgsql.ColumnKindIDs}) - } - - case cypher.CountFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.FunctionCall{ - Function: pgsql.FunctionCount, - Parameters: []pgsql.Expression{argument}, - CastType: pgsql.Int8, - }) - } - - case cypher.StringSplitToArrayFunction: - if typedExpression.NumArguments() != 2 { - s.SetError(fmt.Errorf("expected two arguments for cypher function %s", typedExpression.Name)) - } else if delimiter, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else if splitReference, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - if _, hasHint := GetTypeHint(splitReference); !hasHint { - // Do our best to coerce the type into text - if typedSplitRef, err := TypeCastExpression(splitReference, pgsql.Text); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.FunctionCall{ - Function: pgsql.FunctionStringToArray, - Parameters: []pgsql.Expression{typedSplitRef, delimiter}, - CastType: pgsql.TextArray, - }) - } - } else { - s.treeTranslator.Push(pgsql.FunctionCall{ - Function: pgsql.FunctionStringToArray, - Parameters: []pgsql.Expression{splitReference, delimiter}, - CastType: pgsql.TextArray, - }) - } - } - - case cypher.ToLowerFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { - // Rewrite the property lookup operator with a JSON text field lookup - propertyLookup.Operator = pgsql.OperatorJSONTextField - } - - s.treeTranslator.Push(pgsql.FunctionCall{ - Function: pgsql.FunctionToLower, - Parameters: []pgsql.Expression{argument}, - CastType: pgsql.Text, - }) - } - - case cypher.ListSizeFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - var functionCall pgsql.FunctionCall - - if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { - // Ensure that the JSONB array length function receives the JSONB type - propertyLookup.Operator = pgsql.OperatorJSONField - - functionCall = pgsql.FunctionCall{ - Function: pgsql.FunctionJSONBArrayLength, - Parameters: []pgsql.Expression{argument}, - CastType: pgsql.Int, - } - } else { - functionCall = pgsql.FunctionCall{ - Function: pgsql.FunctionArrayLength, - Parameters: []pgsql.Expression{argument, pgsql.NewLiteral(1, pgsql.Int)}, - CastType: pgsql.Int, - } - } - - s.treeTranslator.Push(functionCall) - } - - case cypher.ToUpperFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { - // Rewrite the property lookup operator with a JSON text field lookup - propertyLookup.Operator = pgsql.OperatorJSONTextField - } - - s.treeTranslator.Push(pgsql.FunctionCall{ - Function: pgsql.FunctionToUpper, - Parameters: []pgsql.Expression{argument}, - CastType: pgsql.Text, - }) - } - - case cypher.ToStringFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Text)) - } - - case cypher.ToIntegerFunction: - if typedExpression.NumArguments() != 1 { - s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) - } else if argument, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Int8)) - } - - case cypher.CoalesceFunction: - if err := s.translateCoalesceFunction(typedExpression); err != nil { - s.SetError(err) - } - - default: - s.SetErrorf("unknown cypher function: %s", typedExpression.Name) - } + s.translateFunction(typedExpression) case *cypher.UnaryAddOrSubtractExpression: if operand, err := s.treeTranslator.Pop(); err != nil { @@ -673,72 +335,57 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case *cypher.Negation: - s.exitState(StateTranslatingNestedExpression) - - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if operand, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else { - for cursor := operand; cursor != nil; { - switch typedCursor := cursor.(type) { - case pgsql.Parenthetical: - // Unwrap parentheticals - cursor = typedCursor.Expression - continue - - case *pgsql.BinaryExpression: - switch typedCursor.Operator { - case pgsql.OperatorLike, pgsql.OperatorILike: - // If this is a string comparison operation then the negation requires wrapping the - // operand references in coalesce functions. While this will kick out index acceleration - // the negation will already damage the query planner's ability to utilize an index lookup. - - if leftPropertyLookup, isPropertyLookup := asPropertyLookup(typedCursor.LOperand); isPropertyLookup { - typedCursor.LOperand = pgsql.FunctionCall{ - Function: pgsql.FunctionCoalesce, - Parameters: []pgsql.Expression{ - leftPropertyLookup, - pgsql.NewLiteral("", pgsql.Text), - }, - CastType: pgsql.Text, - } + if operand, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + for cursor := operand; cursor != nil; { + switch typedCursor := cursor.(type) { + case pgsql.Parenthetical: + // Unwrap parentheticals + cursor = typedCursor.Expression + continue + + case *pgsql.BinaryExpression: + switch typedCursor.Operator { + case pgsql.OperatorLike, pgsql.OperatorILike: + // If this is a string comparison operation then the negation requires wrapping the + // operand references in coalesce functions. While this will kick out index acceleration + // the negation will already damage the query planner's ability to utilize an index lookup. + + if leftPropertyLookup, isPropertyLookup := asPropertyLookup(typedCursor.LOperand); isPropertyLookup { + typedCursor.LOperand = pgsql.FunctionCall{ + Function: pgsql.FunctionCoalesce, + Parameters: []pgsql.Expression{ + leftPropertyLookup, + pgsql.NewLiteral("", pgsql.Text), + }, + CastType: pgsql.Text, } + } - if rightPropertyLookup, isPropertyLookup := asPropertyLookup(typedCursor.ROperand); isPropertyLookup { - typedCursor.ROperand = pgsql.FunctionCall{ - Function: pgsql.FunctionCoalesce, - Parameters: []pgsql.Expression{ - rightPropertyLookup, - pgsql.NewLiteral("", pgsql.Text), - }, - CastType: pgsql.Text, - } + if rightPropertyLookup, isPropertyLookup := asPropertyLookup(typedCursor.ROperand); isPropertyLookup { + typedCursor.ROperand = pgsql.FunctionCall{ + Function: pgsql.FunctionCoalesce, + Parameters: []pgsql.Expression{ + rightPropertyLookup, + pgsql.NewLiteral("", pgsql.Text), + }, + CastType: pgsql.Text, } } } - - break } - s.treeTranslator.Push(&pgsql.UnaryExpression{ - Operator: pgsql.OperatorNot, - Operand: operand, - }) + break } - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) + s.treeTranslator.Push(&pgsql.UnaryExpression{ + Operator: pgsql.OperatorNot, + Operand: operand, + }) } - case *cypher.PatternPart: - s.exitState(StateTranslatingPatternPart) - case *cypher.Where: - // Validate state transitions - s.exitState(StateTranslatingNestedExpression) - s.exitState(StateTranslatingWhere) - // Assign the last operands as identifier set constraints if err := s.treeTranslator.PopRemainingExpressionsAsConstraints(); err != nil { s.SetError(err) @@ -748,106 +395,63 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.translatePropertyLookup(typedExpression) case *cypher.PartialComparison: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.Operator(typedExpression.Operator)); err != nil { - s.SetError(err) - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) + if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.Operator(typedExpression.Operator)); err != nil { + s.SetError(err) } case *cypher.PartialArithmeticExpression: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.Operator(typedExpression.Operator)); err != nil { - s.SetError(err) - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) + if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.Operator(typedExpression.Operator)); err != nil { + s.SetError(err) } case *cypher.Disjunction: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - for idx := 0; idx < typedExpression.Len()-1; idx++ { - if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorOr); err != nil { - s.SetError(err) - } + for idx := 0; idx < typedExpression.Len()-1; idx++ { + if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorOr); err != nil { + s.SetError(err) } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) } case *cypher.Conjunction: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - for idx := 0; idx < typedExpression.Len()-1; idx++ { - if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorAnd); err != nil { - s.SetError(err) - } + for idx := 0; idx < typedExpression.Len()-1; idx++ { + if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorAnd); err != nil { + s.SetError(err) } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) } - case *cypher.ProjectionItem: - s.exitState(StateTranslatingNestedExpression) + case *cypher.Projection: + s.SetError(nil) + case *cypher.ProjectionItem: if err := s.translateProjectionItem(s.query.Scope, typedExpression); err != nil { s.SetError(err) } - case *cypher.Projection: - s.exitState(StateTranslatingProjection) - - case *cypher.Return: - - case *cypher.Create: - s.exitState(StateTranslatingCreate) - case *cypher.Match: - s.exitState(StateTranslatingMatch) - - if err := s.buildMatch(s.match.Scope); err != nil { + if err := s.buildMatch(); err != nil { s.SetError(err) } - case *cypher.SinglePartQuery: - if s.mutations.Assignments.Len() > 0 { - if err := s.translateUpdates(s.query.Scope); err != nil { - s.SetError(err) - } - - if err := s.buildUpdates(s.query.Scope); err != nil { - s.SetError(err) - } + case *cypher.With: + if err := s.translateWith(); err != nil { + s.SetError(err) } - if s.mutations.Deletions.Len() > 0 { - if err := s.buildDeletions(s.query.Scope); err != nil { - s.SetError(err) - } + case *cypher.MultiPartQueryPart: + if err := s.translateMultiPartQueryPart(s.query.Scope, typedExpression); err != nil { + s.SetError(err) } - // If there was no return specified end the CTE chain with a bare select - if typedExpression.Return == nil { - if literalReturn, err := pgsql.AsLiteral(1); err != nil { - s.SetError(err) - } else { - s.query.Model.Body = pgsql.Select{ - Projection: []pgsql.SelectItem{literalReturn}, - } - } - } else if err := s.buildProjection(s.query.Scope); err != nil { + case *cypher.SinglePartQuery: + if err := s.buildSinglePartQuery(typedExpression); err != nil { s.SetError(err) } - s.translation.Statement = *s.query.Model + s.translation.Statement = *s.query.CurrentPart().Model + + case *cypher.MultiPartQuery: + if err := s.buildMultiPartQuery(typedExpression.SinglePartQuery); err != nil { + s.SetError(err) + } } } diff --git a/packages/go/cypher/models/pgsql/translate/update.go b/packages/go/cypher/models/pgsql/translate/update.go index 0f7b42e080..a148f30022 100644 --- a/packages/go/cypher/models/pgsql/translate/update.go +++ b/packages/go/cypher/models/pgsql/translate/update.go @@ -24,7 +24,7 @@ import ( ) func (s *Translator) translateUpdates(scope *Scope) error { - for _, identifierMutation := range s.mutations.Assignments.Values() { + for _, identifierMutation := range s.query.CurrentPart().mutations.Assignments.Values() { if stepFrame, err := s.query.Scope.PushFrame(); err != nil { return err } else { @@ -40,7 +40,7 @@ func (s *Translator) translateUpdates(scope *Scope) error { return fmt.Errorf("expected aliased expression to have an alias set") } else if typedProjection.Alias.Value == identifierMutation.TargetBinding.Identifier { // This is the projection being replaced by the assignment - if rewrittenProjections, err := buildProjection(identifierMutation.TargetBinding.Identifier, identifierMutation.UpdateBinding, scope); err != nil { + if rewrittenProjections, err := buildProjection(identifierMutation.TargetBinding.Identifier, identifierMutation.UpdateBinding, scope, scope.ReferenceFrame()); err != nil { return err } else { identifierMutation.Projection = append(identifierMutation.Projection, rewrittenProjections...) @@ -64,7 +64,7 @@ func (s *Translator) translateUpdates(scope *Scope) error { } func (s *Translator) buildUpdates(scope *Scope) error { - for _, identifierMutation := range s.mutations.Assignments.Values() { + for _, identifierMutation := range s.query.CurrentPart().mutations.Assignments.Values() { sqlUpdate := pgsql.Update{ From: []pgsql.FromClause{{ Source: pgsql.TableReference{ @@ -297,7 +297,7 @@ func (s *Translator) buildUpdates(scope *Scope) error { sqlUpdate.Returning = identifierMutation.Projection sqlUpdate.Where = models.ValueOptional(joinConstraint.Expression) - s.query.Model.AddCTE(pgsql.CommonTableExpression{ + s.query.CurrentPart().Model.AddCTE(pgsql.CommonTableExpression{ Alias: pgsql.TableAlias{ Name: identifierMutation.Frame.Binding.Identifier, }, diff --git a/packages/go/cypher/models/walk/walk.go b/packages/go/cypher/models/walk/walk.go index c2f9a75176..a337f75eb7 100644 --- a/packages/go/cypher/models/walk/walk.go +++ b/packages/go/cypher/models/walk/walk.go @@ -54,8 +54,10 @@ func (s *cancelableErrorHandler) SetDone() { } func (s *cancelableErrorHandler) SetError(err error) { - s.errs = append(s.errs, err) - s.done = true + if err != nil { + s.errs = append(s.errs, err) + s.done = true + } } func (s *cancelableErrorHandler) SetErrorf(format string, args ...any) { diff --git a/packages/go/cypher/models/walk/walk_cypher.go b/packages/go/cypher/models/walk/walk_cypher.go index f7f5a35c6d..1191bb728d 100644 --- a/packages/go/cypher/models/walk/walk_cypher.go +++ b/packages/go/cypher/models/walk/walk_cypher.go @@ -224,14 +224,14 @@ func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], er Node: node, } - if typedNode.Where != nil { - nextCursor.AddBranches(typedNode.Where) - } - if typedNode.Projection != nil { nextCursor.AddBranches(typedNode.Projection) } + if typedNode.Where != nil { + nextCursor.AddBranches(typedNode.Where) + } + return nextCursor, nil case *cypher.Quantifier: diff --git a/packages/go/cypher/models/walk/walk_pgsql.go b/packages/go/cypher/models/walk/walk_pgsql.go index 29924ae140..f365bfe973 100644 --- a/packages/go/cypher/models/walk/walk_pgsql.go +++ b/packages/go/cypher/models/walk/walk_pgsql.go @@ -331,6 +331,12 @@ func newSQLWalkCursor(node pgsql.SyntaxNode) (*Cursor[pgsql.SyntaxNode], error) }, nil } + case pgsql.Subquery: + return &Cursor[pgsql.SyntaxNode]{ + Node: node, + Branches: []pgsql.SyntaxNode{typedNode.Query}, + }, nil + default: return nil, fmt.Errorf("unable to negotiate sql type %T into a translation cursor", node) } diff --git a/packages/go/dawgs/drivers/pg/tooling.go b/packages/go/dawgs/drivers/pg/tooling.go index 32f9fdb65c..18493a9ab8 100644 --- a/packages/go/dawgs/drivers/pg/tooling.go +++ b/packages/go/dawgs/drivers/pg/tooling.go @@ -17,7 +17,6 @@ package pg import ( - "log/slog" "regexp" "sync" @@ -53,7 +52,6 @@ type queryHook struct { func (s *queryHook) Execute(query string, arguments ...any) { switch s.action { case actionTrace: - slog.Info("Here") } }