From b4bd50cffb5470ef0836d3ca1df88f650b9a4344 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 11 Dec 2024 12:18:51 -0800 Subject: [PATCH] BED-5091 - Merge stage `v6.3.0` (#1009) * set ff to true for 6.3.0 migration file (#999) * CySQL Updates (#1005) * feat: BED-5146 - support toString and toInt * feat: BED-5130 - support instant type components epochseconds and epocmillis for datetimes * fix: BED-5149, BED-5148 - implement cypher size function and rework type inference handling * fix: BED-5138 - apply dangling constraints to projection select for all shortest path queries * fix: BED-5130 - use numeric type casts to ensure that no loss of precision takes place * fix: BED-5154 - switch entity references in translation to int8[] and index values to int (#1006) --------- Co-authored-by: Arianna Cooper --- .../src/api/v2/datapipe_integration_test.go | 4 +- .../database/migration/migrations/v6.3.0.sql | 3 + packages/go/cypher/models/cypher/functions.go | 17 ++ .../go/cypher/models/pgsql/format/format.go | 27 +++- packages/go/cypher/models/pgsql/functions.go | 2 + .../go/cypher/models/pgsql/identifiers.go | 15 ++ packages/go/cypher/models/pgsql/model.go | 23 ++- packages/go/cypher/models/pgsql/pgtypes.go | 82 +++++++++- .../pgsql/test/translation_cases/nodes.sql | 65 +++++++- .../translation_cases/pattern_binding.sql | 16 +- .../translation_cases/pattern_expansion.sql | 28 ++-- .../test/translation_cases/shortest_paths.sql | 25 ++- .../models/pgsql/translate/expansion.go | 17 +- .../models/pgsql/translate/expression.go | 150 ++++++++++++------ .../models/pgsql/translate/expression_test.go | 2 +- .../models/pgsql/translate/projection.go | 2 +- .../models/pgsql/translate/translation.go | 70 ++++++++ .../models/pgsql/translate/translator.go | 85 +++++----- packages/go/cypher/models/walk/walk_cypher.go | 8 +- packages/go/cypher/models/walk/walk_pgsql.go | 10 ++ .../dawgs/drivers/pg/query/sql/schema_up.sql | 1 - packages/go/dawgs/drivers/pg/types.go | 73 ++++++++- 22 files changed, 593 insertions(+), 132 deletions(-) diff --git a/cmd/api/src/api/v2/datapipe_integration_test.go b/cmd/api/src/api/v2/datapipe_integration_test.go index e992a312ae..8caf3500a4 100644 --- a/cmd/api/src/api/v2/datapipe_integration_test.go +++ b/cmd/api/src/api/v2/datapipe_integration_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// -// SPDXLicenseIdentifier: Apache2.0 +// SPDX-License-Identifier: Apache-2.0 //go:build serial_integration // +build serial_integration diff --git a/cmd/api/src/database/migration/migrations/v6.3.0.sql b/cmd/api/src/database/migration/migrations/v6.3.0.sql index b21d0632cc..c44a919ced 100644 --- a/cmd/api/src/database/migration/migrations/v6.3.0.sql +++ b/cmd/api/src/database/migration/migrations/v6.3.0.sql @@ -33,3 +33,6 @@ UPDATE feature_flags SET enabled = true WHERE key = 'updated_posture_page'; -- Fix users in bad state due to sso bug DELETE FROM auth_secrets WHERE id IN (SELECT auth_secrets.id FROM auth_secrets JOIN users ON users.id = auth_secrets.user_id WHERE users.sso_provider_id IS NOT NULL); + +-- Set the `oidc_support` feature flag to true +UPDATE feature_flags SET enabled = true WHERE key = 'oidc_support'; \ No newline at end of file diff --git a/packages/go/cypher/models/cypher/functions.go b/packages/go/cypher/models/cypher/functions.go index 1ffed87f99..c6dce37f2b 100644 --- a/packages/go/cypher/models/cypher/functions.go +++ b/packages/go/cypher/models/cypher/functions.go @@ -30,4 +30,21 @@ const ( NodeLabelsFunction = "labels" EdgeTypeFunction = "type" StringSplitToArrayFunction = "split" + ToStringFunction = "tostring" + ToIntegerFunction = "toint" + ListSizeFunction = "size" + + // ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/) + ITTCYear = "year" + ITTCMonth = "month" + ITTCDay = "day" + ITTCHour = "hour" + ITTCMinute = "minute" + ITTCSecond = "second" + ITTCMillisecond = "millisecond" + ITTCMicrosecond = "microsecond" + ITTCNanosecond = "nanosecond" + ITTCTimeZone = "timezone" + ITTCEpochSeconds = "epochseconds" + ITTCEpochMilliseconds = "epochmillis" ) diff --git a/packages/go/cypher/models/pgsql/format/format.go b/packages/go/cypher/models/pgsql/format/format.go index ac2caeaf21..4280aef044 100644 --- a/packages/go/cypher/models/pgsql/format/format.go +++ b/packages/go/cypher/models/pgsql/format/format.go @@ -137,6 +137,12 @@ func formatValue(builder *OutputBuilder, value any) error { case bool: builder.Write(strconv.FormatBool(typedValue)) + case float32: + builder.Write(strconv.FormatFloat(float64(typedValue), 'f', -1, 64)) + + case float64: + builder.Write(strconv.FormatFloat(typedValue, 'f', -1, 64)) + default: return fmt.Errorf("unsupported literal type: %T", value) } @@ -482,8 +488,27 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error { exprStack = append(exprStack, pgsql.FormattingLiteral("not ")) } + case pgsql.ProjectionFrom: + for idx, projection := range typedNextExpr.Projection { + if idx > 0 { + builder.Write(", ") + } + + if err := formatNode(builder, projection); err != nil { + return err + } + } + + if len(typedNextExpr.From) > 0 { + builder.Write(" from ") + + if err := formatFromClauses(builder, typedNextExpr.From); err != nil { + return err + } + } + default: - return fmt.Errorf("unsupported node type: %T", nextExpr) + return fmt.Errorf("unable to format pgsql node type: %T", nextExpr) } } diff --git a/packages/go/cypher/models/pgsql/functions.go b/packages/go/cypher/models/pgsql/functions.go index 92bb29f247..be5267764a 100644 --- a/packages/go/cypher/models/pgsql/functions.go +++ b/packages/go/cypher/models/pgsql/functions.go @@ -24,6 +24,7 @@ const ( FunctionJSONBToTextArray Identifier = "jsonb_to_text_array" FunctionJSONBArrayElementsText Identifier = "jsonb_array_elements_text" FunctionJSONBBuildObject Identifier = "jsonb_build_object" + 
FunctionJSONBArrayLength Identifier = "jsonb_array_length" FunctionArrayLength Identifier = "array_length" FunctionArrayAggregate Identifier = "array_agg" FunctionMin Identifier = "min" @@ -41,4 +42,5 @@ const ( FunctionCount Identifier = "count" FunctionStringToArray Identifier = "string_to_array" FunctionEdgesToPath Identifier = "edges_to_path" + FunctionExtract Identifier = "extract" ) diff --git a/packages/go/cypher/models/pgsql/identifiers.go b/packages/go/cypher/models/pgsql/identifiers.go index 4e19ab092a..582290bd93 100644 --- a/packages/go/cypher/models/pgsql/identifiers.go +++ b/packages/go/cypher/models/pgsql/identifiers.go @@ -25,8 +25,23 @@ import ( const ( WildcardIdentifier Identifier = "*" + EpochIdentifier Identifier = "epoch" ) +var reservedIdentifiers = []Identifier{ + EpochIdentifier, +} + +func IsReservedIdentifier(identifier Identifier) bool { + for _, reservedIdentifier := range reservedIdentifiers { + if identifier == reservedIdentifier { + return true + } + } + + return false +} + func AsOptionalIdentifier(identifier Identifier) models.Optional[Identifier] { return models.ValueOptional(identifier) } diff --git a/packages/go/cypher/models/pgsql/model.go b/packages/go/cypher/models/pgsql/model.go index acc97590b4..c890a855e0 100644 --- a/packages/go/cypher/models/pgsql/model.go +++ b/packages/go/cypher/models/pgsql/model.go @@ -403,9 +403,17 @@ type AnyExpression struct { } func NewAnyExpression(inner Expression) AnyExpression { - return AnyExpression{ + newAnyExpression := AnyExpression{ Expression: inner, } + + // This is a guard to prevent recursive wrapping of an expression in an Any expression + switch innerTypeHint := inner.(type) { + case TypeHinted: + newAnyExpression.CastType = innerTypeHint.TypeHint() + } + + return newAnyExpression } func (s AnyExpression) AsExpression() Expression { @@ -972,6 +980,19 @@ func (s Projection) NodeType() string { return "projection" } +type ProjectionFrom struct { + Projection Projection + From []FromClause +} + +func (s ProjectionFrom) NodeType() string { + return "projection from" +} + +func (s ProjectionFrom) AsExpression() Expression { + return s +} + // Select is a SQL expression that is evaluated to fetch data. 
type Select struct { Distinct bool diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go index de4a12ba6f..8839e37971 100644 --- a/packages/go/cypher/models/pgsql/pgtypes.go +++ b/packages/go/cypher/models/pgsql/pgtypes.go @@ -93,6 +93,8 @@ const ( TextArray DataType = "text[]" JSONB DataType = "jsonb" JSONBArray DataType = "jsonb[]" + Numeric DataType = "numeric" + NumericArray DataType = "numeric[]" Date DataType = "date" TimeWithTimeZone DataType = "time with time zone" TimeWithoutTimeZone DataType = "time without time zone" @@ -109,7 +111,8 @@ const ( ExpansionTerminalNode DataType = "expansion_terminal_node" ) -func (s DataType) Convert(other DataType) (DataType, bool) { +// TODO: operator, while unused, is part of a refactor for this function to make it operator aware +func (s DataType) Compatible(other DataType, operator Operator) (DataType, bool) { if s == other { return s, true } @@ -132,6 +135,12 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Float8: return Float8, true + case Float4Array: + return Float4, true + + case Float8Array: + return Float8, true + case Text: return Text, true } @@ -141,6 +150,21 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Float4: return Float8, true + case Float4Array, Float8Array: + return Float8, true + + case Text: + return Text, true + } + + case Numeric: + switch other { + case Float4, Float8, Int2, Int4, Int8: + return Numeric, true + + case Float4Array, Float8Array, NumericArray: + return Numeric, true + case Text: return Text, true } @@ -156,6 +180,15 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Int8: return Int8, true + case Int2Array: + return Int2, true + + case Int4Array: + return Int4, true + + case Int8Array: + return Int8, true + case Text: return Text, true } @@ -168,6 +201,12 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Int8: return Int8, true + case Int2Array, Int4Array: + return Int4, true + + case Int8Array: + return Int8, true + case Text: return Text, true } @@ -177,9 +216,42 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Int2, Int4, Int8: return Int8, true + case Int2Array, Int4Array, Int8Array: + return Int8, true + + case Text: + return Text, true + } + + case Int: + switch other { + case Int2, Int4, Int: + return Int, true + + case Int8: + return Int8, true + case Text: return Text, true } + + case Int2Array: + switch other { + case Int2Array, Int4Array, Int8Array: + return other, true + } + + case Int4Array: + switch other { + case Int4Array, Int8Array: + return other, true + } + + case Float4Array: + switch other { + case Float4Array, Float8Array: + return other, true + } } return UnsetDataType, false @@ -207,7 +279,7 @@ func (s DataType) MatchesOneOf(others ...DataType) bool { func (s DataType) IsArrayType() bool { switch s { - case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray: + case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray, JSONBArray, NodeCompositeArray, EdgeCompositeArray, NumericArray: return true } @@ -239,6 +311,8 @@ func (s DataType) ToArrayType() (DataType, error) { return Float8Array, nil case Text, TextArray: return TextArray, nil + case Numeric, NumericArray: + return NumericArray, nil default: return UnknownDataType, ErrNoAvailableArrayDataType } @@ -258,8 +332,10 @@ func (s DataType) ArrayBaseType() (DataType, error) { return Float8, nil case TextArray: return Text, nil + case NumericArray: + return 
Numeric, nil default: - return UnknownDataType, ErrNonArrayDataType + return s, nil } } diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql index 8c18b2ed0b..69b09a8b98 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql @@ -82,7 +82,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 - where (s0.n0).id = any (jsonb_to_text_array(n1.properties -> 'captured_ids')::int4[])) + where (s0.n0).id = any (jsonb_to_text_array(n1.properties -> 'captured_ids')::int8[])) select s1.n0 as s, s1.n1 as e from s1; @@ -498,3 +498,66 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 where n0.properties ->> 'system_tags' like '%' || ('text')::text) select s0.n0 as n from s0; + +-- case: match (n:NodeKind1) where toString(n.functionallevel) in ['2008 R2','2012','2008','2003','2003 Interim','2000 Mixed/Native'] return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'functionallevel')::text = any + (array ['2008 R2', '2012', '2008', '2003', '2003 Interim', '2000 Mixed/Native']::text[])) +select s0.n0 as n +from s0; + +-- case: match (n:NodeKind1) where toInt(n.value) in [1, 2, 3, 4] return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'value')::int8 = any (array [1, 2, 3, 4]::int8[])) +select s0.n0 as n +from s0; + +-- case: match (u:NodeKind1) where u.pwdlastset < (datetime().epochseconds - (365 * 86400)) and not u.pwdlastset IN [-1.0, 0.0] return u limit 100 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'pwdlastset')::numeric < + (extract(epoch from now()::timestamp with time zone)::numeric - (365 * 86400)) + and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) +select s0.n0 as u +from s0 +limit 100; + +-- case: match (u:NodeKind1) where u.pwdlastset < (datetime().epochmillis - (365 * 86400000)) and not u.pwdlastset IN [-1.0, 0.0] return u limit 100 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'pwdlastset')::numeric < + (extract(epoch from now()::timestamp with time zone)::numeric * 1000 - (365 * 86400000)) + and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) +select s0.n0 as u +from s0 +limit 100; + +-- case: match (n:NodeKind1) where size(n.array_value) > 0 return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and jsonb_array_length(n0.properties -> 'array_value')::int > 0) +select s0.n0 as n +from s0; + +-- case: match (n) where 1 in n.array return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where 1 = any (jsonb_to_text_array(n0.properties -> 'array')::int8[])) +select s0.n0 as n +from s0; + +-- case: match (n) 
where $p in n.array or $f in n.array return n +-- cypher_params: {"p": 1, "f": "text"} +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where @pi0::float8 = any (jsonb_to_text_array(n0.properties -> 'array')::float8[]) + or @pi1::text = any (jsonb_to_text_array(n0.properties -> 'array')::text[])) +select s0.n0 as n +from s0; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql index 2c85c05023..4cb3465067 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql @@ -21,7 +21,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id) -select edges_to_path(variadic array [(s0.e0).id]::int4[])::pathcomposite as p +select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p from s0; -- case: match p = ()-[r1]->()-[r2]->(e) return e @@ -92,7 +92,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite edge e1 join node n2 on (n2.properties -> 'is_target')::bool and n2.id = e1.start_id where (s0.n1).id = e1.end_id) -select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int4[])::pathcomposite as p +select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int8[])::pathcomposite as p from s1; -- case: match p = ()-[*..]->() return p limit 1 @@ -126,7 +126,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id) select edges_to_path(variadic ep0)::pathcomposite as p from s0 limit 1; @@ -162,7 +162,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied), s1 as (select s0.e0 as e0, s0.ep0 as ep0, @@ -174,7 +174,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat edge e1 join node n2 on n2.id = e1.end_id where (s0.n1).id = e1.start_id) -select edges_to_path(variadic array [(s1.e1).id]::int4[] || s1.ep0)::pathcomposite as p +select edges_to_path(variadic array [(s1.e1).id]::int8[] || s1.ep0)::pathcomposite as p from s1 limit 1; @@ -220,8 +220,8 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite ex0 join edge e1 on e1.id = any (ex0.path) join node n1 on n1.id = ex0.root_id - join node n2 on e1.id = ex0.path[array_length(ex0.path, 1)::int4] and n2.id = e1.end_id) -select s1.e0 as e, edges_to_path(variadic array [(s1.e0).id]::int4[] || s1.ep0)::pathcomposite as p + join node n2 on e1.id = ex0.path[array_length(ex0.path, 1)::int] and n2.id = e1.end_id) +select s1.e0 as e, edges_to_path(variadic array [(s1.e0).id]::int8[] || s1.ep0)::pathcomposite as p from s1; -- case: match p = (m:NodeKind1)-[:EdgeKind1]->(c:NodeKind2) where m.objectid ends with "-513" and not toUpper(c.operatingsystem) contains "SERVER" return p limit 1000 @@ -235,6 +235,6 @@ 
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite not upper(n1.properties ->> 'operatingsystem')::text like '%SERVER%' and n1.id = e0.end_id where e0.kind_id = any (array [11]::int2[])) -select edges_to_path(variadic array [(s0.e0).id]::int4[])::pathcomposite as p +select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p from s0 limit 1000; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql index e49a4e0e84..87da96d391 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql @@ -45,7 +45,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id) select s0.n0 as n, s0.n1 as e from s0; @@ -80,7 +80,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select s0.n1 as e from s0; @@ -119,7 +119,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select s0.n1 as e from s0; @@ -155,7 +155,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select s0.n0 as n from s0; @@ -191,7 +191,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied), s1 as (select s0.e0 as e0, s0.ep0 as ep0, @@ -237,7 +237,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied), s1 as (select s0.e0 as e0, s0.ep0 as ep0, @@ -286,7 +286,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat ex1 join edge e2 on e2.id = any (ex1.path) join node n2 on n2.id = ex1.root_id - join node n3 on e2.id = ex1.path[array_length(ex1.path, 1)::int4] and n3.id = e2.end_id) + join node n3 on e2.id = 
ex1.path[array_length(ex1.path, 1)::int] and n3.id = e2.end_id) select s2.n3 as l from s2; @@ -332,7 +332,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0 @@ -372,7 +372,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied and n0.id <> n1.id) select edges_to_path(variadic ep0)::pathcomposite as p @@ -422,7 +422,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0; @@ -463,7 +463,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.start_id) select edges_to_path(variadic ep0)::pathcomposite as p from s0 limit 10; @@ -505,7 +505,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.start_id where ex0.satisfied), s1 as (with recursive ex1(root_id, next_id, depth, satisfied, is_cycle, path) as (select e1.start_id, e1.end_id, @@ -544,7 +544,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat ex1 join edge e1 on e1.id = any (ex1.path) join node n1 on n1.id = ex1.root_id - join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int4] and n2.id = e1.start_id) + join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int] and n2.id = e1.start_id) select edges_to_path(variadic s1.ep1 || s1.ep0)::pathcomposite as p from s1 limit 10; @@ -593,7 +593,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0 diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql index 5544c30586..b7035887f1 100644 --- 
a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql @@ -28,7 +28,7 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id) select edges_to_path(variadic ep0)::pathcomposite as p from s0; @@ -46,7 +46,28 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0; + +-- case: match p=shortestPath((n:NodeKind1)-[:EdgeKind1*1..]->(m)) where 'admin_tier_0' in split(m.system_tags, ' ') and n.objectid ends with '-513' and n<>m return p limit 1000 +-- cypher_params: {} +-- pgsql_params: {"pi0":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select e0.start_id, e0.end_id, 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.start_id = e0.end_id, array [e0.id] from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.properties ->> 'objectid' like '%-513' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [11]::int2[]);", "pi1":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select ex0.root_id, e0.end_id, ex0.depth + 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.id = any (ex0.path), ex0.path || e0.id from pathspace ex0 join edge e0 on e0.start_id = ex0.next_id join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle and e0.kind_id = any (array [11]::int2[]);"} +with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) + as (select * from asp_harness(@pi0::text, @pi1::text, 5)) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id + where ex0.satisfied + and n0.id <> n1.id) +select edges_to_path(variadic ep0)::pathcomposite as p +from s0 +limit 1000; diff --git a/packages/go/cypher/models/pgsql/translate/expansion.go b/packages/go/cypher/models/pgsql/translate/expansion.go index 872b66c2bb..e819d7cb07 100644 --- a/packages/go/cypher/models/pgsql/translate/expansion.go +++ b/packages/go/cypher/models/pgsql/translate/expansion.go @@ -320,7 +320,7 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath}, pgsql.NewLiteral(1, pgsql.Int8), }, - CastType: pgsql.Int4, + CastType: pgsql.Int, }, }, }, @@ -417,8 +417,15 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, 
trave ), } - // Make sure to only accept paths that are satisfied - expansion.ProjectionStatement.Where = pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionSatisfied} + // Constraints that target the terminal node may crop up here where it's finally in scope. Additionally, + // only accept paths that are marked satisfied from the recursive descent CTE + if constraints, err := consumeConstraintsFrom(traversalStep.Expansion.Value.Frame.Visible, s.treeTranslator.IdentifierConstraints); err != nil { + return pgsql.Query{}, err + } else if projectionConstraints, err := ConjoinExpressions([]pgsql.Expression{pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionSatisfied}, constraints.Expression}); err != nil { + return pgsql.Query{}, err + } else { + expansion.ProjectionStatement.Where = projectionConstraints + } } } else { expansion.PrimerStatement.Projection = []pgsql.SelectItem{ @@ -653,7 +660,7 @@ func (s *Translator) buildExpansionPatternRoot(part *PatternPart, traversalStep pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath}, pgsql.NewLiteral(1, pgsql.Int8), }, - CastType: pgsql.Int4, + CastType: pgsql.Int, }, }, }, @@ -940,7 +947,7 @@ func (s *Translator) buildExpansionPatternStep(part *PatternPart, traversalStep pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath}, pgsql.NewLiteral(1, pgsql.Int8), }, - CastType: pgsql.Int4, + CastType: pgsql.Int, }, }, }, diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go index 1da57810e6..1fdd24a364 100644 --- a/packages/go/cypher/models/pgsql/translate/expression.go +++ b/packages/go/cypher/models/pgsql/translate/expression.go @@ -31,8 +31,15 @@ type PropertyLookup struct { } func asPropertyLookup(expression pgsql.Expression) (*pgsql.BinaryExpression, bool) { - if binaryExpression, isBinaryExpression := expression.(*pgsql.BinaryExpression); isBinaryExpression { - return binaryExpression, pgsql.OperatorIsPropertyLookup(binaryExpression.Operator) + switch typedExpression := expression.(type) { + case pgsql.AnyExpression: + // This is here to unwrap Any expressions that have been passed in as a property lookup. This is + // common when dealing with array operators. In the future this check should be handled by the + // caller to simplify the logic here. 
+ return asPropertyLookup(typedExpression.Expression) + + case *pgsql.BinaryExpression: + return typedExpression, pgsql.OperatorIsPropertyLookup(typedExpression.Operator) } return nil, false @@ -64,10 +71,17 @@ func ExtractSyntaxNodeReferences(root pgsql.SyntaxNode) (*pgsql.IdentifierSet, e func(node pgsql.SyntaxNode, errorHandler walk.CancelableErrorHandler) { switch typedNode := node.(type) { case pgsql.Identifier: - dependencies.Add(typedNode) + // Filter for reserved identifiers + if !pgsql.IsReservedIdentifier(typedNode) { + dependencies.Add(typedNode) + } case pgsql.CompoundIdentifier: - dependencies.Add(typedNode.Root()) + identifier := typedNode.Root() + + if !pgsql.IsReservedIdentifier(identifier) { + dependencies.Add(identifier) + } } }, )) @@ -83,7 +97,6 @@ func applyUnaryExpressionTypeHints(expression *pgsql.UnaryExpression) error { func rewritePropertyLookupOperator(propertyLookup *pgsql.BinaryExpression, dataType pgsql.DataType) pgsql.Expression { if dataType.IsArrayType() { - // This property lookup needs to be coerced into an array type using a function return pgsql.FunctionCall{ Function: pgsql.FunctionJSONBToTextArray, Parameters: []pgsql.Expression{propertyLookup}, @@ -126,7 +139,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy if isLeftHinted { if isRightHinted { - if higherLevelHint, matchesOrConverts := leftHint.Convert(rightHint); !matchesOrConverts { + if higherLevelHint, matchesOrConverts := leftHint.Compatible(rightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, rightHint) } else { return higherLevelHint, nil @@ -136,44 +149,52 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy } else if inferredRightHint == pgsql.UnknownDataType { // Assume the right side is convertable and return the left operand hint return leftHint, nil - } else if upcastHint, matchesOrConverts := leftHint.Convert(inferredRightHint); !matchesOrConverts { + } else if upcastHint, matchesOrConverts := leftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredRightHint) } else { return upcastHint, nil } } else if isRightHinted { // There's no left type, attempt to infer it - if inferredLeftHint, err := InferExpressionType(expression.ROperand); err != nil { + if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil { return pgsql.UnsetDataType, err } else if inferredLeftHint == pgsql.UnknownDataType { // Assume the right side is convertable and return the left operand hint return rightHint, nil - } else if upcastHint, matchesOrConverts := rightHint.Convert(inferredLeftHint); !matchesOrConverts { - return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredLeftHint) + } else if upcastHint, matchesOrConverts := rightHint.Compatible(inferredLeftHint, expression.Operator); !matchesOrConverts { + return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, rightHint, inferredLeftHint) } else { return upcastHint, nil } - } else if inferredLeftHint, err := 
InferExpressionType(expression.LOperand); err != nil { - return pgsql.UnsetDataType, err - } else if inferredRightHint, err := InferExpressionType(expression.ROperand); err != nil { - return pgsql.UnsetDataType, err - } else if inferredLeftHint == pgsql.UnknownDataType && inferredRightHint == pgsql.UnknownDataType { - // If neither side has type information then check the operator to see if it implies some type hinting + } else { + // If neither side has specific type information then check the operator to see if it implies some type + // hinting before resorting to inference switch expression.Operator { case pgsql.OperatorStartsWith, pgsql.OperatorContains, pgsql.OperatorEndsWith: // String operations imply the operands must be text return pgsql.Text, nil - // TODO: Boolean inference for OperatorAnd and OperatorOr may want to be plumbed here + case pgsql.OperatorAnd, pgsql.OperatorOr: + // Boolean operators that the operands must be boolean + return pgsql.Boolean, nil default: - // Unable to infer any type information - return pgsql.UnknownDataType, nil + // The operator does not imply specific type information onto the operands. Attempt to infer any + // information as a last ditch effort to type the AST nodes + if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil { + return pgsql.UnsetDataType, err + } else if inferredRightHint, err := InferExpressionType(expression.ROperand); err != nil { + return pgsql.UnsetDataType, err + } else if inferredLeftHint == pgsql.UnknownDataType && inferredRightHint == pgsql.UnknownDataType { + // Unable to infer any type information, this may be resolved elsewhere so this is not explicitly + // an error condition + return pgsql.UnknownDataType, nil + } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts { + return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, inferredLeftHint, inferredRightHint) + } else { + return higherLevelHint, nil + } } - } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Convert(inferredRightHint); !matchesOrConverts { - return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredLeftHint) - } else { - return higherLevelHint, nil } } @@ -191,7 +212,7 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { // Infer type information for well known column names switch typedExpression[1] { case pgsql.ColumnGraphID, pgsql.ColumnID, pgsql.ColumnStartID, pgsql.ColumnEndID: - return pgsql.Int4, nil + return pgsql.Int8, nil case pgsql.ColumnKindID: return pgsql.Int2, nil @@ -215,13 +236,18 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { // This is unknown, not unset meaning that it can be re-cast by future inference inspections return pgsql.UnknownDataType, nil - case pgsql.OperatorAnd, pgsql.OperatorOr: + case pgsql.OperatorAnd, pgsql.OperatorOr, pgsql.OperatorEquals, pgsql.OperatorGreaterThan, pgsql.OperatorGreaterThanOrEqualTo, + pgsql.OperatorLessThan, pgsql.OperatorLessThanOrEqualTo, pgsql.OperatorIn, pgsql.OperatorJSONBFieldExists, + pgsql.OperatorLike, pgsql.OperatorILike, pgsql.OperatorPGArrayOverlap: return pgsql.Boolean, nil default: return inferBinaryExpressionType(typedExpression) } + case pgsql.Parenthetical: + return 
InferExpressionType(typedExpression.Expression) + default: log.Infof("unable to infer type hint for expression type: %T", expression) return pgsql.UnknownDataType, nil @@ -263,31 +289,65 @@ func TypeCastExpression(expression pgsql.Expression, dataType pgsql.DataType) (p return pgsql.NewTypeCast(expression, dataType), nil } -func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression, expressionTypeHint pgsql.DataType) error { - if leftPropertyLookup, isPropertyLookup := asPropertyLookup(expression.LOperand); isPropertyLookup { - if lookupRequiresElementType(expressionTypeHint, expression.Operator, expression.ROperand) { - // Take the base type of the array type hint: in - if arrayBaseType, err := expressionTypeHint.ArrayBaseType(); err != nil { - return err - } else { - expressionTypeHint = arrayBaseType - } - } +func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { + var ( + leftPropertyLookup, hasLeftPropertyLookup = asPropertyLookup(expression.LOperand) + rightPropertyLookup, hasRightPropertyLookup = asPropertyLookup(expression.ROperand) + ) - expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, expressionTypeHint) + // Don't rewrite direct property comparisons + if hasLeftPropertyLookup && hasRightPropertyLookup { + return nil } - if rightPropertyLookup, isPropertyLookup := asPropertyLookup(expression.ROperand); isPropertyLookup { - if lookupRequiresElementType(expressionTypeHint, expression.Operator, expression.LOperand) { - // Take the base type of the array type hint: in - if arrayBaseType, err := expressionTypeHint.ArrayBaseType(); err != nil { + if hasLeftPropertyLookup { + // This check exists here to prevent from overwriting a property lookup that's part of a in + // binary expression. 
This may want for better ergonomics in the future + if anyExpression, isAnyExpression := expression.ROperand.(pgsql.AnyExpression); isAnyExpression { + if arrayBaseType, err := anyExpression.CastType.ArrayBaseType(); err != nil { return err } else { - expressionTypeHint = arrayBaseType + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) + } + } else if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { + return err + } else { + switch expression.Operator { + case pgsql.OperatorIn: + if arrayBaseType, err := rOperandTypeHint.ArrayBaseType(); err != nil { + return err + } else { + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) + } + + case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch: + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, pgsql.Text) + + default: + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, rOperandTypeHint) } } + } + + if hasRightPropertyLookup { + if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil { + return err + } else { + switch expression.Operator { + case pgsql.OperatorIn: + if arrayType, err := lOperandTypeHint.ToArrayType(); err != nil { + return err + } else { + expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, arrayType) + } + + case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch: + expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, pgsql.Text) - expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, expressionTypeHint) + default: + expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, lOperandTypeHint) + } + } } return nil @@ -301,11 +361,7 @@ func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error { return nil } - if expressionTypeHint, err := InferExpressionType(expression); err != nil { - return err - } else { - return rewritePropertyLookupOperands(expression, expressionTypeHint) - } + return rewritePropertyLookupOperands(expression) } type Builder struct { diff --git a/packages/go/cypher/models/pgsql/translate/expression_test.go b/packages/go/cypher/models/pgsql/translate/expression_test.go index 802c9eab75..1e514ce38e 100644 --- a/packages/go/cypher/models/pgsql/translate/expression_test.go +++ b/packages/go/cypher/models/pgsql/translate/expression_test.go @@ -72,7 +72,7 @@ func TestInferExpressionType(t *testing.T) { ), ), }, { - ExpectedType: pgsql.Text, + ExpectedType: pgsql.Boolean, Expression: pgsql.NewBinaryExpression( mustAsLiteral("123"), pgsql.OperatorIn, diff --git a/packages/go/cypher/models/pgsql/translate/projection.go b/packages/go/cypher/models/pgsql/translate/projection.go index 6bf21f4920..5b3259dfbd 100644 --- a/packages/go/cypher/models/pgsql/translate/projection.go +++ b/packages/go/cypher/models/pgsql/translate/projection.go @@ -197,7 +197,7 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * pgsql.OperatorConcatenate, pgsql.ArrayLiteral{ Values: edgeReferences, - CastType: pgsql.Int4Array, + CastType: pgsql.Int8Array, }, ) } diff --git a/packages/go/cypher/models/pgsql/translate/translation.go b/packages/go/cypher/models/pgsql/translate/translation.go index 03eb58cd20..5bc71308a9 100644 --- a/packages/go/cypher/models/pgsql/translate/translation.go +++ b/packages/go/cypher/models/pgsql/translate/translation.go @@ -60,6 
+60,76 @@ func (s *Translator) translateRemoveItem(removeItem *cypher.RemoveItem) error { return nil } +func (s *Translator) translatePropertyLookup(lookup *cypher.PropertyLookup) { + if translatedAtom, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + switch typedTranslatedAtom := translatedAtom.(type) { + case pgsql.Identifier: + if fieldIdentifierLiteral, err := pgsql.AsLiteral(lookup.Symbols[0]); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.CompoundIdentifier{typedTranslatedAtom, pgsql.ColumnProperties}) + s.treeTranslator.Push(fieldIdentifierLiteral) + + if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorPropertyLookup); err != nil { + s.SetError(err) + } + } + + case pgsql.FunctionCall: + if fieldIdentifierLiteral, err := pgsql.AsLiteral(lookup.Symbols[0]); err != nil { + s.SetError(err) + } else if componentName, typeOK := fieldIdentifierLiteral.Value.(string); !typeOK { + s.SetErrorf("expected a string component name in translated literal but received type: %T", fieldIdentifierLiteral.Value) + } else { + switch typedTranslatedAtom.Function { + case pgsql.FunctionCurrentDate, pgsql.FunctionLocalTime, pgsql.FunctionCurrentTime, pgsql.FunctionLocalTimestamp, pgsql.FunctionNow: + switch componentName { + case cypher.ITTCEpochSeconds: + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionExtract, + Parameters: []pgsql.Expression{pgsql.ProjectionFrom{ + Projection: []pgsql.SelectItem{ + pgsql.EpochIdentifier, + }, + From: []pgsql.FromClause{{ + Source: translatedAtom, + }}, + }}, + CastType: pgsql.Numeric, + }) + + case cypher.ITTCEpochMilliseconds: + s.treeTranslator.Push(pgsql.NewBinaryExpression( + pgsql.FunctionCall{ + Function: pgsql.FunctionExtract, + Parameters: []pgsql.Expression{pgsql.ProjectionFrom{ + Projection: []pgsql.SelectItem{ + pgsql.EpochIdentifier, + }, + From: []pgsql.FromClause{{ + Source: translatedAtom, + }}, + }}, + CastType: pgsql.Numeric, + }, + pgsql.OperatorMultiply, + pgsql.NewLiteral(1000, pgsql.Int4), + )) + + default: + s.SetErrorf("unsupported date time instant type component %s from function call %s", componentName, typedTranslatedAtom.Function) + } + + default: + s.SetErrorf("unsupported instant type component %s from function call %s", componentName, typedTranslatedAtom.Function) + } + } + } + } +} + func (s *Translator) translateSetItem(setItem *cypher.SetItem) error { if operator, err := translateCypherAssignmentOperator(setItem.Operator); err != nil { return err diff --git a/packages/go/cypher/models/pgsql/translate/translator.go b/packages/go/cypher/models/pgsql/translate/translator.go index 68b057f5f6..02fa92daca 100644 --- a/packages/go/cypher/models/pgsql/translate/translator.go +++ b/packages/go/cypher/models/pgsql/translate/translator.go @@ -128,7 +128,7 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { case *cypher.RegularQuery, *cypher.SingleQuery, *cypher.PatternElement, *cypher.Return, *cypher.Comparison, *cypher.Skip, *cypher.Limit, cypher.Operator, *cypher.ArithmeticExpression, *cypher.NodePattern, *cypher.RelationshipPattern, *cypher.Remove, *cypher.Set, - *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression: + *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression, *cypher.PropertyLookup: // No operation for these syntax nodes case *cypher.Negation: @@ -259,27 +259,6 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { case *cypher.FunctionInvocation: s.pushState(StateTranslatingNestedExpression) - case 
*cypher.PropertyLookup: - if variable, isVariable := typedExpression.Atom.(*cypher.Variable); !isVariable { - s.SetErrorf("expected variable for property lookup reference but found type: %T", typedExpression.Atom) - } else if resolved, isResolved := s.query.Scope.LookupString(variable.Symbol); !isResolved { - s.SetErrorf("unable to resolve identifier: %s", variable.Symbol) - } else { - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - // TODO: Cypher does not support nested property references so the Symbols slice should be a string - if fieldIdentifierLiteral, err := pgsql.AsLiteral(typedExpression.Symbols[0]); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.CompoundIdentifier{resolved.Identifier, pgsql.ColumnProperties}) - s.treeTranslator.Push(fieldIdentifierLiteral) - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } - } - case *cypher.Order: s.pushState(StateTranslatingOrderBy) @@ -610,6 +589,31 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { }) } + case cypher.ListSizeFunction: + if typedExpression.NumArguments() > 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + var functionCall pgsql.FunctionCall + + if _, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + functionCall = pgsql.FunctionCall{ + Function: pgsql.FunctionJSONBArrayLength, + Parameters: []pgsql.Expression{argument}, + CastType: pgsql.Int, + } + } else { + functionCall = pgsql.FunctionCall{ + Function: pgsql.FunctionArrayLength, + Parameters: []pgsql.Expression{argument, pgsql.NewLiteral(1, pgsql.Int)}, + CastType: pgsql.Int, + } + } + + s.treeTranslator.Push(functionCall) + } + case cypher.ToUpperFunction: if typedExpression.NumArguments() > 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) @@ -628,6 +632,24 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { }) } + case cypher.ToStringFunction: + if typedExpression.NumArguments() > 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Text)) + } + + case cypher.ToIntegerFunction: + if typedExpression.NumArguments() > 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Int8)) + } + default: s.SetErrorf("unknown cypher function: %s", typedExpression.Name) } @@ -715,24 +737,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case *cypher.PropertyLookup: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorPropertyLookup); err != nil { - s.SetError(err) - } - - case StateTranslatingProjection: - if nextExpression, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else if selectItem, isProjection := nextExpression.(pgsql.SelectItem); !isProjection { - s.SetErrorf("invalid type for select item: %T", nextExpression) - } else { - 
s.projections.CurrentProjection().SelectItem = selectItem - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } + s.translatePropertyLookup(typedExpression) case *cypher.PartialComparison: switch currentState := s.currentState(); currentState { diff --git a/packages/go/cypher/models/walk/walk_cypher.go b/packages/go/cypher/models/walk/walk_cypher.go index ee7d0a581a..f7f5a35c6d 100644 --- a/packages/go/cypher/models/walk/walk_cypher.go +++ b/packages/go/cypher/models/walk/walk_cypher.go @@ -36,12 +36,18 @@ func cypherSyntaxNodeSliceTypeConvert[F any, FS []F](fs FS) ([]cypher.SyntaxNode func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], error) { switch typedNode := node.(type) { // Types with no AST branches - case *cypher.RangeQuantifier, *cypher.PropertyLookup, cypher.Operator, *cypher.KindMatcher, + case *cypher.RangeQuantifier, cypher.Operator, *cypher.KindMatcher, *cypher.Limit, *cypher.Skip, graph.Kinds, *cypher.Parameter: return &Cursor[cypher.SyntaxNode]{ Node: node, }, nil + case *cypher.PropertyLookup: + return &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: []cypher.SyntaxNode{typedNode.Atom}, + }, nil + case *cypher.MapItem: return &Cursor[cypher.SyntaxNode]{ Node: node, diff --git a/packages/go/cypher/models/walk/walk_pgsql.go b/packages/go/cypher/models/walk/walk_pgsql.go index 65df9c7cce..29924ae140 100644 --- a/packages/go/cypher/models/walk/walk_pgsql.go +++ b/packages/go/cypher/models/walk/walk_pgsql.go @@ -321,6 +321,16 @@ func newSQLWalkCursor(node pgsql.SyntaxNode) (*Cursor[pgsql.SyntaxNode], error) Branches: []pgsql.SyntaxNode{typedNode.Subquery}, }, nil + case pgsql.ProjectionFrom: + if branches, err := pgsqlSyntaxNodeSliceTypeConvert(typedNode.From); err != nil { + return nil, err + } else { + return &Cursor[pgsql.SyntaxNode]{ + Node: node, + Branches: append([]pgsql.SyntaxNode{typedNode.Projection}, branches...), + }, nil + } + default: return nil, fmt.Errorf("unable to negotiate sql type %T into a translation cursor", node) } diff --git a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql index 3ba6038445..3876967d2c 100644 --- a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql +++ b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql @@ -190,7 +190,6 @@ execute procedure delete_node_edges(); alter table edge alter column properties set storage main; - -- Index on the graph ID of each edge. 
create index if not exists edge_graph_id_index on edge using btree (graph_id); diff --git a/packages/go/dawgs/drivers/pg/types.go b/packages/go/dawgs/drivers/pg/types.go index a0305b84af..3ce3ae9f01 100644 --- a/packages/go/dawgs/drivers/pg/types.go +++ b/packages/go/dawgs/drivers/pg/types.go @@ -60,11 +60,76 @@ func castMapValueAsSliceOf[T any](compositeMap map[string]any, key string) ([]T, func castAndAssignMapValue[T any](compositeMap map[string]any, key string, dst *T) error { if src, hasKey := compositeMap[key]; !hasKey { return fmt.Errorf("composite map does not contain expected key %s", key) - } else if typed, typeOK := src.(T); !typeOK { - var empty T - return fmt.Errorf("expected type %T but received %T", empty, src) } else { - *dst = typed + switch typedSrc := src.(type) { + case int8: + switch typedDst := any(dst).(type) { + case *int8: + *typedDst = typedSrc + case *int16: + *typedDst = int16(typedSrc) + case *int32: + *typedDst = int32(typedSrc) + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int16: + switch typedDst := any(dst).(type) { + case *int16: + *typedDst = typedSrc + case *int32: + *typedDst = int32(typedSrc) + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int32: + switch typedDst := any(dst).(type) { + case *int32: + *typedDst = typedSrc + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int64: + switch typedDst := any(dst).(type) { + case *int64: + *typedDst = typedSrc + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int: + switch typedDst := any(dst).(type) { + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = typedSrc + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case T: + *dst = typedSrc + + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } } return nil
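
The castAndAssignMapValue change in the final hunk above replaces a strict generic type assertion with integer-width-aware assignment, so values decoded from node and edge composites (for example int4 identifiers) can be widened into int64 or int destinations instead of failing with a type mismatch. Below is a condensed, illustrative sketch of that widening technique; assignWidened and its reduced set of cases are assumptions made for the example and are not part of the patch itself.

package main

import "fmt"

// assignWidened is a condensed sketch of the technique used by
// castAndAssignMapValue above: when the decoded value and the destination are
// both integers, the value is widened into the destination; otherwise the
// original exact-match generic assignment applies.
func assignWidened[T any](src any, dst *T) error {
	switch typedSrc := src.(type) {
	case int32:
		switch typedDst := any(dst).(type) {
		case *int32:
			*typedDst = typedSrc
		case *int64:
			*typedDst = int64(typedSrc)
		case *int:
			*typedDst = int(typedSrc)
		default:
			return fmt.Errorf("unable to cast and assign value type: %T", src)
		}

	case T:
		*dst = typedSrc

	default:
		return fmt.Errorf("unable to cast and assign value type: %T", src)
	}

	return nil
}

func main() {
	var id int64

	// An int32 decoded from a composite column assigns cleanly to an int64
	// destination; under the old exact-match assertion this would have failed
	// with "expected type int64 but received int32".
	if err := assignWidened(int32(42), &id); err != nil {
		fmt.Println(err)
	} else {
		fmt.Println(id) // 42
	}
}

The production helper in the hunk enumerates each source width (int8 through int64 and int) separately so that values are only ever widened, never narrowed, which keeps the conversion lossless for the int8-based identifiers introduced elsewhere in this merge.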