Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CySQL - Multi-Part Query Support Fixes #1112

Open
wants to merge 1 commit into
base: stage/v7.0.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions packages/go/cypher/frontend/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,44 +353,61 @@ func (s *SinglePartQueryVisitor) ExitOC_UpdatingClause(ctx *parser.OC_UpdatingCl
type MultiPartQueryVisitor struct {
BaseVisitor

Query *cypher.MultiPartQuery
Query *cypher.MultiPartQuery
partIdx int
}

func NewMultiPartQueryVisitor() *MultiPartQueryVisitor {
return &MultiPartQueryVisitor{
Query: cypher.NewMultiPartQuery(),
Query: cypher.NewMultiPartQuery(),
partIdx: 0,
}
}

func (s *MultiPartQueryVisitor) EnterOC_ReadingClause(ctx *parser.OC_ReadingClauseContext) {
// If the part index is equal to the length of parts then this signifies that a new query part
// is required. We do not advance the index here - this is done with the following `with`
// cypher AST component
if len(s.Query.Parts) == s.partIdx {
s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart())
}

s.ctx.Enter(NewReadingClauseVisitor())
}

func (s *MultiPartQueryVisitor) ExitOC_ReadingClause(ctx *parser.OC_ReadingClauseContext) {
part := cypher.NewMultiPartQueryPart()
part.AddReadingClause(s.ctx.Exit().(*ReadingClauseVisitor).ReadingClause)
s.Query.Parts = append(s.Query.Parts, part)
s.Query.CurrentPart().AddReadingClause(s.ctx.Exit().(*ReadingClauseVisitor).ReadingClause)
}

func (s *MultiPartQueryVisitor) EnterOC_UpdatingClause(ctx *parser.OC_UpdatingClauseContext) {
if len(s.Query.Parts) == s.partIdx {
s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart())
}

s.ctx.Enter(NewUpdatingClauseVisitor())
}

func (s *MultiPartQueryVisitor) ExitOC_UpdatingClause(ctx *parser.OC_UpdatingClauseContext) {
// Make sure to mark that this multipart query part contains a mutation (non-read operation)
s.ctx.HasMutation = true
part := cypher.NewMultiPartQueryPart()
part.AddUpdatingClause(s.ctx.Exit().(*UpdatingClauseVisitor).UpdatingClause)
s.Query.Parts = append(s.Query.Parts, part)

s.Query.CurrentPart().AddUpdatingClause(s.ctx.Exit().(*UpdatingClauseVisitor).UpdatingClause)
}

func (s *MultiPartQueryVisitor) EnterOC_With(ctx *parser.OC_WithContext) {
if len(s.Query.Parts) == s.partIdx {
s.Query.Parts = append(s.Query.Parts, cypher.NewMultiPartQueryPart())
}

s.ctx.Enter(NewWithVisitor())
}

func (s *MultiPartQueryVisitor) ExitOC_With(ctx *parser.OC_WithContext) {
part := cypher.NewMultiPartQueryPart()
part.With = s.ctx.Exit().(*WithVisitor).With
s.Query.Parts = append(s.Query.Parts, part)
s.Query.CurrentPart().With = s.ctx.Exit().(*WithVisitor).With

// Advance the part index so a new multipart query part gets allocated for the next reading
// or updating clause
s.partIdx += 1
}

func (s *MultiPartQueryVisitor) EnterOC_SinglePartQuery(ctx *parser.OC_SinglePartQueryContext) {
Expand Down
1 change: 1 addition & 0 deletions packages/go/cypher/models/cypher/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const (
ToIntegerFunction = "toint"
ListSizeFunction = "size"
CoalesceFunction = "coalesce"
CollectFunction = "collect"

// ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/)
ITTCYear = "year"
Expand Down
4 changes: 4 additions & 0 deletions packages/go/cypher/models/cypher/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ type MultiPartQuery struct {
SinglePartQuery *SinglePartQuery
}

func (s *MultiPartQuery) CurrentPart() *MultiPartQueryPart {
return s.Parts[len(s.Parts)-1]
}

func NewMultiPartQuery() *MultiPartQuery {
return &MultiPartQuery{}
}
Expand Down
5 changes: 4 additions & 1 deletion packages/go/cypher/models/pgsql/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
}

case pgsql.ExistsExpression:
exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Subquery, pgsql.FormattingLiteral("exists ("))
exprStack = append(exprStack, typedNextExpr.Subquery, pgsql.FormattingLiteral("exists "))

if typedNextExpr.Negated {
exprStack = append(exprStack, pgsql.FormattingLiteral("not "))
Expand All @@ -517,6 +517,9 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
}
}

case pgsql.Subquery:
exprStack = append(exprStack, pgsql.FormattingLiteral(")"), typedNextExpr.Query, pgsql.FormattingLiteral("("))

default:
return fmt.Errorf("unable to format pgsql node type: %T", nextExpr)
}
Expand Down
10 changes: 10 additions & 0 deletions packages/go/cypher/models/pgsql/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,13 @@ const (
FunctionEdgesToPath Identifier = "edges_to_path"
FunctionExtract Identifier = "extract"
)

func IsAggregateFunction(function Identifier) bool {
switch function {
case FunctionCount, FunctionArrayAggregate:
return true

default:
return false
}
}
4 changes: 4 additions & 0 deletions packages/go/cypher/models/pgsql/identifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func AsIdentifierSet(identifiers ...Identifier) *IdentifierSet {
return newSet
}

func (s *IdentifierSet) Clear() {
clear(s.identifiers)
}

func (s *IdentifierSet) Len() int {
return len(s.identifiers)
}
Expand Down
27 changes: 13 additions & 14 deletions packages/go/cypher/models/pgsql/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ type Subquery struct {
Query Query
}

func (s Subquery) NodeType() string {
return "subquery"
}

func (s Subquery) AsExpression() Expression {
return s
}

// not <expr>
type UnaryExpression struct {
Operator Expression
Expand Down Expand Up @@ -321,6 +329,10 @@ type Parenthetical struct {
Expression Expression
}

func (s Parenthetical) AsSelectItem() SelectItem {
return s
}

func (s Parenthetical) NodeType() string {
return "parenthetical"
}
Expand Down Expand Up @@ -1022,7 +1034,6 @@ func (s Select) NodeType() string {
// select 1
// union
// select 2;

type SetOperation struct {
Operator Operator
LOperand SetExpression
Expand All @@ -1047,7 +1058,7 @@ func (s SetOperation) NodeType() string {
//
// [not] exists(<query>)
type ExistsExpression struct {
Subquery Query
Subquery Subquery
Negated bool
}

Expand Down Expand Up @@ -1132,18 +1143,6 @@ func (s Query) NodeType() string {
return "query"
}

func BinaryExpressionJoinTyped(optional Expression, operator Operator, conjoined *BinaryExpression) *BinaryExpression {
if optional == nil {
return conjoined
}

return NewBinaryExpression(
conjoined,
operator,
optional,
)
}

func BinaryExpressionJoin(optional Expression, operator Operator, conjoined Expression) Expression {
if optional == nil {
return conjoined
Expand Down
1 change: 1 addition & 0 deletions packages/go/cypher/models/pgsql/pgtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ const (
TimestampWithoutTimeZone DataType = "timestamp without time zone"

Scope DataType = "scope"
InlineProjection DataType = "inline_projection"
ParameterIdentifier DataType = "parameter_identifier"
ExpansionPattern DataType = "expansion_pattern"
ExpansionPath DataType = "expansion_path"
Expand Down
Loading
Loading