diff --git a/cmd/api/src/analysis/ad/adcs_integration_test.go b/cmd/api/src/analysis/ad/adcs_integration_test.go index 39e9bd9213..5a85fe3ed2 100644 --- a/cmd/api/src/analysis/ad/adcs_integration_test.go +++ b/cmd/api/src/analysis/ad/adcs_integration_test.go @@ -447,63 +447,75 @@ func TestTrustedForNTAuth(t *testing.T) { func TestEnrollOnBehalfOf(t *testing.T) { testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { - harness.EnrollOnBehalfOfHarnessOne.Setup(testContext) + harness.EnrollOnBehalfOfHarness1.Setup(testContext) return nil }, func(harness integration.HarnessDetails, db graph.Database) { certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate) v1Templates := make([]*graph.Node, 0) + v2Templates := make([]*graph.Node, 0) + for _, template := range certTemplates { if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil { continue } else if version == 1 { v1Templates = append(v1Templates, template) } else if version >= 2 { - continue + v2Templates = append(v2Templates, template) } } require.Nil(t, err) db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { - results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates) + results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1) require.Nil(t, err) require.Len(t, results, 3) require.Contains(t, results, analysis.CreatePostRelationshipJob{ - FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate11.ID, - ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID, + FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate11.ID, + ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID, Kind: ad.EnrollOnBehalfOf, }) require.Contains(t, results, analysis.CreatePostRelationshipJob{ - FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate13.ID, - ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID, + FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate13.ID, + ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID, Kind: ad.EnrollOnBehalfOf, }) require.Contains(t, results, analysis.CreatePostRelationshipJob{ - FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID, - ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID, + FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID, + ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID, Kind: ad.EnrollOnBehalfOf, }) return nil }) + + db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { + results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1) + require.Nil(t, err) + + require.Len(t, results, 0) + + return nil + }) }) testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { - harness.EnrollOnBehalfOfHarnessTwo.Setup(testContext) + harness.EnrollOnBehalfOfHarness2.Setup(testContext) return nil }, func(harness integration.HarnessDetails, db graph.Database) { certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate) + v1Templates := make([]*graph.Node, 0) v2Templates := make([]*graph.Node, 0) for _, template := range certTemplates { if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil { continue } else if version == 1 { - continue + v1Templates = append(v1Templates, template) } else if version >= 2 { v2Templates = append(v2Templates, template) } @@ -512,15 +524,60 @@ func TestEnrollOnBehalfOf(t *testing.T) { require.Nil(t, err) db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { - results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates) + results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2) + require.Nil(t, err) + + require.Len(t, results, 0) + return nil + }) + + db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { + results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2) require.Nil(t, err) require.Len(t, results, 1) require.Contains(t, results, analysis.CreatePostRelationshipJob{ - FromID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate21.ID, - ToID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate23.ID, + FromID: harness.EnrollOnBehalfOfHarness2.CertTemplate21.ID, + ToID: harness.EnrollOnBehalfOfHarness2.CertTemplate23.ID, Kind: ad.EnrollOnBehalfOf, }) + return nil + }) + }) + + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { + harness.EnrollOnBehalfOfHarness3.Setup(testContext) + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { + operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - EnrollOnBehalfOf 3") + + _, enterpriseCertAuthorities, certTemplates, domains, cache, err := FetchADCSPrereqs(db) + require.Nil(t, err) + + if err := ad2.PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates, cache, operation); err != nil { + t.Logf("failed post processing for %s: %v", ad.EnrollOnBehalfOf.String(), err) + } + err = operation.Done() + require.Nil(t, err) + + db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { + if startNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria { + return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf) + })); err != nil { + t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err) + } else if endNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria { + return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf) + })); err != nil { + t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err) + } else { + require.Len(t, startNodes, 2) + require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate11)) + require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12)) + + require.Len(t, endNodes, 2) + require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12)) + require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12)) + } return nil }) diff --git a/cmd/api/src/api/v2/datapipe_integration_test.go b/cmd/api/src/api/v2/datapipe_integration_test.go index e992a312ae..8caf3500a4 100644 --- a/cmd/api/src/api/v2/datapipe_integration_test.go +++ b/cmd/api/src/api/v2/datapipe_integration_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// SPDXLicenseIdentifier: Apache2.0 +// SPDX-License-Identifier: Apache-2.0 //go:build serial_integration // +build serial_integration diff --git a/cmd/api/src/database/migration/migrations/v6.3.0.sql b/cmd/api/src/database/migration/migrations/v6.3.0.sql index b21d0632cc..c44a919ced 100644 --- a/cmd/api/src/database/migration/migrations/v6.3.0.sql +++ b/cmd/api/src/database/migration/migrations/v6.3.0.sql @@ -33,3 +33,6 @@ UPDATE feature_flags SET enabled = true WHERE key = 'updated_posture_page'; -- Fix users in bad state due to sso bug DELETE FROM auth_secrets WHERE id IN (SELECT auth_secrets.id FROM auth_secrets JOIN users ON users.id = auth_secrets.user_id WHERE users.sso_provider_id IS NOT NULL); + +-- Set the `oidc_support` feature flag to true +UPDATE feature_flags SET enabled = true WHERE key = 'oidc_support'; \ No newline at end of file diff --git a/cmd/api/src/test/integration/harnesses.go b/cmd/api/src/test/integration/harnesses.go index 637aebc7c0..dede7f384a 100644 --- a/cmd/api/src/test/integration/harnesses.go +++ b/cmd/api/src/test/integration/harnesses.go @@ -1595,7 +1595,7 @@ func (s *ADCSESC1HarnessAuthUsers) Setup(graphTestContext *GraphTestContext) { graphTestContext.UpdateNode(s.AuthUsers) } -type EnrollOnBehalfOfHarnessTwo struct { +type EnrollOnBehalfOfHarness2 struct { Domain2 *graph.Node AuthStore2 *graph.Node RootCA2 *graph.Node @@ -1604,10 +1604,9 @@ type EnrollOnBehalfOfHarnessTwo struct { CertTemplate22 *graph.Node CertTemplate23 *graph.Node CertTemplate24 *graph.Node - CertTemplate25 *graph.Node } -func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) { +func (s *EnrollOnBehalfOfHarness2) Setup(gt *GraphTestContext) { certRequestAgentEKU := make([]string, 0) certRequestAgentEKU = append(certRequestAgentEKU, adAnalysis.EkuCertRequestAgent) emptyAppPolicies := make([]string, 0) @@ -1623,7 +1622,7 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) { SubjectAltRequireUPN: false, SubjectAltRequireSPN: false, NoSecurityExtension: false, - SchemaVersion: 1, + SchemaVersion: 2, AuthorizedSignatures: 0, EffectiveEKUs: certRequestAgentEKU, ApplicationPolicies: emptyAppPolicies, @@ -1635,7 +1634,7 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) { SubjectAltRequireUPN: false, SubjectAltRequireSPN: false, NoSecurityExtension: false, - SchemaVersion: 1, + SchemaVersion: 2, AuthorizedSignatures: 0, EffectiveEKUs: []string{adAnalysis.EkuCertRequestAgent, adAnalysis.EkuAnyPurpose}, ApplicationPolicies: emptyAppPolicies, @@ -1664,18 +1663,6 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) { EffectiveEKUs: emptyAppPolicies, ApplicationPolicies: emptyAppPolicies, }) - s.CertTemplate25 = gt.NewActiveDirectoryCertTemplate("certtemplate2-5", sid, CertTemplateData{ - RequiresManagerApproval: false, - AuthenticationEnabled: false, - EnrolleeSuppliesSubject: false, - SubjectAltRequireUPN: false, - SubjectAltRequireSPN: false, - NoSecurityExtension: false, - SchemaVersion: 1, - AuthorizedSignatures: 1, - EffectiveEKUs: emptyAppPolicies, - ApplicationPolicies: emptyAppPolicies, - }) gt.NewRelationship(s.AuthStore2, s.Domain2, ad.NTAuthStoreFor) gt.NewRelationship(s.RootCA2, s.Domain2, ad.RootCAFor) @@ -1685,10 +1672,9 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) { gt.NewRelationship(s.CertTemplate22, s.EnterpriseCA2, ad.PublishedTo) gt.NewRelationship(s.CertTemplate23, s.EnterpriseCA2, ad.PublishedTo) gt.NewRelationship(s.CertTemplate24, s.EnterpriseCA2, ad.PublishedTo) - gt.NewRelationship(s.CertTemplate25, s.EnterpriseCA2, ad.PublishedTo) } -type EnrollOnBehalfOfHarnessOne struct { +type EnrollOnBehalfOfHarness1 struct { Domain1 *graph.Node AuthStore1 *graph.Node RootCA1 *graph.Node @@ -1698,7 +1684,7 @@ type EnrollOnBehalfOfHarnessOne struct { CertTemplate13 *graph.Node } -func (s *EnrollOnBehalfOfHarnessOne) Setup(gt *GraphTestContext) { +func (s *EnrollOnBehalfOfHarness1) Setup(gt *GraphTestContext) { sid := RandomDomainSID() anyPurposeEkus := make([]string, 0) anyPurposeEkus = append(anyPurposeEkus, adAnalysis.EkuAnyPurpose) @@ -1753,6 +1739,74 @@ func (s *EnrollOnBehalfOfHarnessOne) Setup(gt *GraphTestContext) { gt.NewRelationship(s.CertTemplate13, s.EnterpriseCA1, ad.PublishedTo) } +type EnrollOnBehalfOfHarness3 struct { + Domain1 *graph.Node + AuthStore1 *graph.Node + RootCA1 *graph.Node + EnterpriseCA1 *graph.Node + EnterpriseCA2 *graph.Node + CertTemplate11 *graph.Node + CertTemplate12 *graph.Node + CertTemplate13 *graph.Node +} + +func (s *EnrollOnBehalfOfHarness3) Setup(gt *GraphTestContext) { + sid := RandomDomainSID() + anyPurposeEkus := make([]string, 0) + anyPurposeEkus = append(anyPurposeEkus, adAnalysis.EkuAnyPurpose) + emptyAppPolicies := make([]string, 0) + s.Domain1 = gt.NewActiveDirectoryDomain("domain1", sid, false, true) + s.AuthStore1 = gt.NewActiveDirectoryNTAuthStore("authstore1", sid) + s.RootCA1 = gt.NewActiveDirectoryRootCA("rca1", sid) + s.EnterpriseCA1 = gt.NewActiveDirectoryEnterpriseCA("eca1", sid) + s.EnterpriseCA2 = gt.NewActiveDirectoryEnterpriseCA("eca2", sid) + s.CertTemplate11 = gt.NewActiveDirectoryCertTemplate("certtemplate1-1", sid, CertTemplateData{ + RequiresManagerApproval: false, + AuthenticationEnabled: false, + EnrolleeSuppliesSubject: false, + SubjectAltRequireUPN: false, + SubjectAltRequireSPN: false, + NoSecurityExtension: false, + SchemaVersion: 2, + AuthorizedSignatures: 0, + EffectiveEKUs: anyPurposeEkus, + ApplicationPolicies: emptyAppPolicies, + }) + s.CertTemplate12 = gt.NewActiveDirectoryCertTemplate("certtemplate1-2", sid, CertTemplateData{ + RequiresManagerApproval: false, + AuthenticationEnabled: false, + EnrolleeSuppliesSubject: false, + SubjectAltRequireUPN: false, + SubjectAltRequireSPN: false, + NoSecurityExtension: false, + SchemaVersion: 1, + AuthorizedSignatures: 0, + EffectiveEKUs: anyPurposeEkus, + ApplicationPolicies: emptyAppPolicies, + }) + s.CertTemplate13 = gt.NewActiveDirectoryCertTemplate("certtemplate1-3", sid, CertTemplateData{ + RequiresManagerApproval: false, + AuthenticationEnabled: false, + EnrolleeSuppliesSubject: false, + SubjectAltRequireUPN: false, + SubjectAltRequireSPN: false, + NoSecurityExtension: false, + SchemaVersion: 2, + AuthorizedSignatures: 0, + EffectiveEKUs: anyPurposeEkus, + ApplicationPolicies: emptyAppPolicies, + }) + + gt.NewRelationship(s.AuthStore1, s.Domain1, ad.NTAuthStoreFor) + gt.NewRelationship(s.RootCA1, s.Domain1, ad.RootCAFor) + gt.NewRelationship(s.EnterpriseCA1, s.AuthStore1, ad.TrustedForNTAuth) + gt.NewRelationship(s.EnterpriseCA1, s.RootCA1, ad.EnterpriseCAFor) + gt.NewRelationship(s.EnterpriseCA2, s.RootCA1, ad.EnterpriseCAFor) + gt.NewRelationship(s.CertTemplate11, s.EnterpriseCA1, ad.PublishedTo) + gt.NewRelationship(s.CertTemplate12, s.EnterpriseCA1, ad.PublishedTo) + gt.NewRelationship(s.CertTemplate13, s.EnterpriseCA2, ad.PublishedTo) +} + type ADCSGoldenCertHarness struct { NTAuthStore1 *graph.Node RootCA1 *graph.Node @@ -8437,8 +8491,9 @@ type HarnessDetails struct { ShortcutHarnessEveryone2 ShortcutHarnessEveryone2 ADCSESC1Harness ADCSESC1Harness ADCSESC1HarnessAuthUsers ADCSESC1HarnessAuthUsers - EnrollOnBehalfOfHarnessOne EnrollOnBehalfOfHarnessOne - EnrollOnBehalfOfHarnessTwo EnrollOnBehalfOfHarnessTwo + EnrollOnBehalfOfHarness1 EnrollOnBehalfOfHarness1 + EnrollOnBehalfOfHarness2 EnrollOnBehalfOfHarness2 + EnrollOnBehalfOfHarness3 EnrollOnBehalfOfHarness3 ADCSGoldenCertHarness ADCSGoldenCertHarness IssuedSignedByHarness IssuedSignedByHarness EnterpriseCAForHarness EnterpriseCAForHarness diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json index cc51045b70..9232e6f121 100644 --- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json @@ -54,10 +54,10 @@ }, "nodes": [ { - "id": "n1", + "id": "n0", "position": { - "x": 729.9551990267428, - "y": -4 + "x": 675.9551990267428, + "y": 50 }, "caption": "Domain1", "labels": [], @@ -67,10 +67,10 @@ } }, { - "id": "n2", + "id": "n1", "position": { - "x": 129, - "y": 273.97628342478527 + "x": 75, + "y": 327.97628342478527 }, "caption": "CertTemplate1-1", "labels": [], @@ -83,10 +83,10 @@ } }, { - "id": "n3", + "id": "n2", "position": { - "x": 487.6313891898351, - "y": -4 + "x": 433.6313891898351, + "y": 50 }, "caption": "NTAuthStore1", "labels": [], @@ -96,10 +96,10 @@ } }, { - "id": "n4", + "id": "n3", "position": { - "x": 487.6313891898351, - "y": 273.97628342478527 + "x": 433.6313891898351, + "y": 327.97628342478527 }, "caption": "EnterpriseCA1", "labels": [], @@ -109,10 +109,10 @@ } }, { - "id": "n5", + "id": "n4", "position": { - "x": 230.03558347937087, - "y": 551.9525668495705 + "x": 176.03558347937087, + "y": 605.9525668495705 }, "caption": "CertTemplate1-2", "labels": [], @@ -125,10 +125,10 @@ } }, { - "id": "n6", + "id": "n5", "position": { - "x": 508.01086036954564, - "y": 551.2045130499298 + "x": 454.01086036954564, + "y": 605.2045130499298 }, "caption": "CertTemplate1-3", "labels": [], @@ -141,10 +141,10 @@ } }, { - "id": "n7", + "id": "n6", "position": { - "x": 729.9551990267428, - "y": 273.97628342478527 + "x": 675.9551990267428, + "y": 327.97628342478527 }, "caption": "RootCA1", "labels": [], @@ -157,64 +157,64 @@ "relationships": [ { "id": "n0", - "fromId": "n3", - "toId": "n1", + "fromId": "n2", + "toId": "n0", "type": "NTAuthStoreFor", "properties": {}, "style": {} }, { "id": "n1", - "fromId": "n4", - "toId": "n3", + "fromId": "n3", + "toId": "n2", "type": "TrustedForNTAuth", "properties": {}, "style": {} }, { "id": "n2", - "fromId": "n2", - "toId": "n4", + "fromId": "n1", + "toId": "n3", "type": "PublishedTo", "properties": {}, "style": {} }, { "id": "n3", - "fromId": "n5", - "toId": "n4", + "fromId": "n4", + "toId": "n3", "type": "PublishedTo", "properties": {}, "style": {} }, { "id": "n4", - "fromId": "n6", - "toId": "n4", + "fromId": "n5", + "toId": "n3", "type": "PublishedTo", "properties": {}, "style": {} }, { "id": "n5", - "fromId": "n7", - "toId": "n1", + "fromId": "n6", + "toId": "n0", "type": "RootCAFor", "properties": {}, "style": {} }, { "id": "n6", - "fromId": "n4", - "toId": "n7", + "fromId": "n3", + "toId": "n6", "type": "EnterpriseCAFor", "properties": {}, "style": {} }, { "id": "n7", - "fromId": "n2", - "toId": "n5", + "fromId": "n1", + "toId": "n4", "type": "EnrollOnBehalfOf", "properties": {}, "style": { @@ -223,13 +223,23 @@ }, { "id": "n8", - "fromId": "n6", - "toId": "n5", + "fromId": "n5", + "toId": "n4", "type": "EnrollOnBehalfOf", "properties": {}, "style": { "arrow-color": "#68ccca" } + }, + { + "id": "n9", + "type": "EnrollOnBehalfOf", + "style": { + "arrow-color": "#68ccca" + }, + "properties": {}, + "fromId": "n4", + "toId": "n4" } ] } \ No newline at end of file diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg index a86d80d2f0..6fcfa6838d 100644 --- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg @@ -1,18 +1 @@ - -NTAuthStoreForTrustedForNTAuthPublishedToPublishedToPublishedToRootCAForEnterpriseCAForEnrollOnBehalfOfEnrollOnBehalfOfDomain1CertTemplate1-1schemaversion:2ekus:["2.5.29.37.0"]NTAuthStore1EnterpriseCA1CertTemplate1-2schemaversion:1ekus:["2.5.29.37.0"]CertTemplate1-3schemaversion:2ekus:["2.5.29.37.0"]RootCA1 +NTAuthStoreForTrustedForNTAuthPublishedToPublishedToPublishedToRootCAForEnterpriseCAForEnrollOnBehalfOfEnrollOnBehalfOfEnrollOnBehalfOfDomain1CertTemplate1-1schemaversion:2effectiveekus:["2.5.29.37.0"]NTAuthStore1EnterpriseCA1CertTemplate1-2schemaversion:1effectiveekus:["2.5.29.37.0"]CertTemplate1-3schemaversion:2effectiveekus:["2.5.29.37.0"]RootCA1 \ No newline at end of file diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json index ee3e2a4b43..607f78cbef 100644 --- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json @@ -1,10 +1,63 @@ { + "style": { + "font-family": "sans-serif", + "background-color": "#ffffff", + "background-image": "", + "background-size": "100%", + "node-color": "#ffffff", + "border-width": 4, + "border-color": "#000000", + "radius": 50, + "node-padding": 5, + "node-margin": 2, + "outside-position": "auto", + "node-icon-image": "", + "node-background-image": "", + "icon-position": "inside", + "icon-size": 64, + "caption-position": "inside", + "caption-max-width": 200, + "caption-color": "#000000", + "caption-font-size": 50, + "caption-font-weight": "normal", + "label-position": "inside", + "label-display": "pill", + "label-color": "#000000", + "label-background-color": "#ffffff", + "label-border-color": "#000000", + "label-border-width": 4, + "label-font-size": 40, + "label-padding": 5, + "label-margin": 4, + "directionality": "directed", + "detail-position": "inline", + "detail-orientation": "parallel", + "arrow-width": 5, + "arrow-color": "#000000", + "margin-start": 5, + "margin-end": 5, + "margin-peer": 20, + "attachment-start": "normal", + "attachment-end": "normal", + "relationship-icon-image": "", + "type-color": "#000000", + "type-background-color": "#ffffff", + "type-border-color": "#000000", + "type-border-width": 0, + "type-font-size": 16, + "type-padding": 5, + "property-position": "outside", + "property-alignment": "colon", + "property-color": "#000000", + "property-font-size": 16, + "property-font-weight": "normal" + }, "nodes": [ { "id": "n0", "position": { - "x": -569.1685177598522, - "y": -1021.0927494329366 + "x": 657.3454879680903, + "y": 50 }, "caption": "Domain2", "labels": [], @@ -14,10 +67,10 @@ } }, { - "id": "n2", + "id": "n1", "position": { - "x": -811.4923275967599, - "y": -1021.0927494329366 + "x": 415.0216781311826, + "y": 50 }, "caption": "NTAuthStore2", "labels": [], @@ -27,10 +80,10 @@ } }, { - "id": "n5", + "id": "n2", "position": { - "x": -811.4923275967599, - "y": -743.1164660081513 + "x": 415.0216781311826, + "y": 327.97628342478527 }, "caption": "EnterpriseCA2", "labels": [], @@ -40,10 +93,10 @@ } }, { - "id": "n13", + "id": "n3", "position": { - "x": -569.1685177598522, - "y": -743.1164660081513 + "x": 657.3454879680903, + "y": 327.97628342478527 }, "caption": "RootCA2", "labels": [], @@ -53,40 +106,42 @@ } }, { - "id": "n14", + "id": "n4", "position": { - "x": -1151.5140057279425, - "y": -743.1164660081513 + "x": 75, + "y": 327.97628342478527 }, "caption": "CertTemplate2-1", "labels": [], "properties": { - "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\"]" + "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\"]", + "schemaversion": "2" }, "style": { "node-color": "#fda1ff" } }, { - "id": "n15", + "id": "n5", "position": { - "x": -1151.5140057279425, - "y": -546.8048586088046 + "x": 75, + "y": 524.287890824132 }, "caption": "CertTemplate2-2", "labels": [], "properties": { - "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\", \"2.5.29.37.0\"]" + "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\", \"2.5.29.37.0\"]", + "schemaversion": "2" }, "style": { "node-color": "#fda1ff" } }, { - "id": "n16", + "id": "n6", "position": { - "x": -981.5031666623508, - "y": -448.6490549091318 + "x": 245.01083906559165, + "y": 622.4436945238048 }, "caption": "CertTemplate2-3", "labels": [], @@ -101,10 +156,10 @@ } }, { - "id": "n17", + "id": "n7", "position": { - "x": -695.4923275967599, - "y": -448.6490549091318 + "x": 531.0216781311826, + "y": 622.4436945238048 }, "caption": "CertTemplate2-4", "labels": [], @@ -117,29 +172,12 @@ "style": { "node-color": "#fda1ff" } - }, - { - "id": "n18", - "position": { - "x": -517.02491649774, - "y": -448.6490549091315 - }, - "caption": "CertTemplate2-5", - "labels": [], - "properties": { - "effectiveekus": "[]", - "schemaversion": "1", - "subjectaltrequiresupn": "true" - }, - "style": { - "node-color": "#fda1ff" - } } ], "relationships": [ { "id": "n0", - "fromId": "n2", + "fromId": "n1", "toId": "n0", "type": "NTAuthStoreFor", "properties": {}, @@ -147,141 +185,69 @@ }, { "id": "n1", - "fromId": "n5", - "toId": "n2", + "fromId": "n2", + "toId": "n1", "type": "TrustedForNTAuth", "properties": {}, "style": {} }, { - "id": "n9", - "fromId": "n13", + "id": "n2", + "fromId": "n3", "toId": "n0", "type": "RootCAFor", "properties": {}, "style": {} }, { - "id": "n10", - "fromId": "n5", - "toId": "n13", + "id": "n3", + "fromId": "n2", + "toId": "n3", "type": "EnterpriseCAFor", "properties": {}, "style": {} }, { - "id": "n11", - "fromId": "n14", - "toId": "n5", - "type": "PublishedTo", - "properties": {}, - "style": {} - }, - { - "id": "n12", - "fromId": "n15", - "toId": "n5", + "id": "n4", + "fromId": "n4", + "toId": "n2", "type": "PublishedTo", "properties": {}, "style": {} }, { - "id": "n13", - "fromId": "n16", - "toId": "n5", + "id": "n5", + "fromId": "n5", + "toId": "n2", "type": "PublishedTo", "properties": {}, "style": {} }, { - "id": "n14", - "fromId": "n17", - "toId": "n5", + "id": "n6", + "fromId": "n6", + "toId": "n2", "type": "PublishedTo", "properties": {}, "style": {} }, { - "id": "n15", - "type": "EnrollOnBehalfOf", - "style": { - "arrow-color": "#a4dd00" - }, - "properties": {}, - "toId": "n16", - "fromId": "n14" - }, - { - "id": "n16", - "fromId": "n18", - "toId": "n5", + "id": "n7", + "fromId": "n7", + "toId": "n2", "type": "PublishedTo", "properties": {}, "style": {} }, { - "id": "n17", - "fromId": "n18", - "toId": "n18", + "id": "n8", + "fromId": "n4", + "toId": "n6", "type": "EnrollOnBehalfOf", "properties": {}, "style": { - "type-color": "#4d4d4d", "arrow-color": "#a4dd00" } } - ], - "style": { - "font-family": "sans-serif", - "background-color": "#ffffff", - "background-image": "", - "background-size": "100%", - "node-color": "#ffffff", - "border-width": 4, - "border-color": "#000000", - "radius": 50, - "node-padding": 5, - "node-margin": 2, - "outside-position": "auto", - "node-icon-image": "", - "node-background-image": "", - "icon-position": "inside", - "icon-size": 64, - "caption-position": "inside", - "caption-max-width": 200, - "caption-color": "#000000", - "caption-font-size": 50, - "caption-font-weight": "normal", - "label-position": "inside", - "label-display": "pill", - "label-color": "#000000", - "label-background-color": "#ffffff", - "label-border-color": "#000000", - "label-border-width": 4, - "label-font-size": 40, - "label-padding": 5, - "label-margin": 4, - "directionality": "directed", - "detail-position": "inline", - "detail-orientation": "parallel", - "arrow-width": 5, - "arrow-color": "#000000", - "margin-start": 5, - "margin-end": 5, - "margin-peer": 20, - "attachment-start": "normal", - "attachment-end": "normal", - "relationship-icon-image": "", - "type-color": "#000000", - "type-background-color": "#ffffff", - "type-border-color": "#000000", - "type-border-width": 0, - "type-font-size": 16, - "type-padding": 5, - "property-position": "outside", - "property-alignment": "colon", - "property-color": "#000000", - "property-font-size": 16, - "property-font-weight": "normal" - } + ] } \ No newline at end of file diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg index dd203c5b22..d8d8ec0fcb 100644 --- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg @@ -1,18 +1 @@ - -NTAuthStoreForTrustedForNTAuthRootCAForEnterpriseCAForPublishedToPublishedToPublishedToPublishedToEnrollOnBehalfOfPublishedToEnrollOnBehalfOfDomain2NTAuthStore2EnterpriseCA2RootCA2CertTemplate2-1ekus:["1.3.6.1.4.1.311.20.2.1"]CertTemplate2-2ekus:["1.3.6.1.4.1.311.20.2.1", "2.5.29.37.0"]CertTemplate2-3ekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:["1.3.6.1.4.1.311.20.2.1"]CertTemplate2-4ekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:[]CertTemplate2-5ekus:[]schemaversion:1subjectaltrequiresupn:true +NTAuthStoreForTrustedForNTAuthRootCAForEnterpriseCAForPublishedToPublishedToPublishedToPublishedToEnrollOnBehalfOfDomain2NTAuthStore2EnterpriseCA2RootCA2CertTemplate2-1effectiveekus:["1.3.6.1.4.1.311.20.2.1"]schemaversion:2CertTemplate2-2effectiveekus:["1.3.6.1.4.1.311.20.2.1", "2.5.29.37.0"]schemaversion:2CertTemplate2-3effectiveekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:["1.3.6.1.4.1.311.20.2.1"]CertTemplate2-4effectiveekus:[]schemaversion:2authorizedsignatures:1applicationpolicies:[] \ No newline at end of file diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.json b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.json new file mode 100644 index 0000000000..e0537e9ec6 --- /dev/null +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.json @@ -0,0 +1,256 @@ +{ + "style": { + "font-family": "sans-serif", + "background-color": "#ffffff", + "background-image": "", + "background-size": "100%", + "node-color": "#ffffff", + "border-width": 4, + "border-color": "#000000", + "radius": 50, + "node-padding": 5, + "node-margin": 2, + "outside-position": "auto", + "node-icon-image": "", + "node-background-image": "", + "icon-position": "inside", + "icon-size": 64, + "caption-position": "inside", + "caption-max-width": 200, + "caption-color": "#000000", + "caption-font-size": 50, + "caption-font-weight": "normal", + "label-position": "inside", + "label-display": "pill", + "label-color": "#000000", + "label-background-color": "#ffffff", + "label-border-color": "#000000", + "label-border-width": 4, + "label-font-size": 40, + "label-padding": 5, + "label-margin": 4, + "directionality": "directed", + "detail-position": "inline", + "detail-orientation": "parallel", + "arrow-width": 5, + "arrow-color": "#000000", + "margin-start": 5, + "margin-end": 5, + "margin-peer": 20, + "attachment-start": "normal", + "attachment-end": "normal", + "relationship-icon-image": "", + "type-color": "#000000", + "type-background-color": "#ffffff", + "type-border-color": "#000000", + "type-border-width": 0, + "type-font-size": 16, + "type-padding": 5, + "property-position": "outside", + "property-alignment": "colon", + "property-color": "#000000", + "property-font-size": 16, + "property-font-weight": "normal" + }, + "nodes": [ + { + "id": "n0", + "position": { + "x": 675.9551990267428, + "y": 50 + }, + "caption": "Domain1", + "labels": [], + "properties": {}, + "style": { + "node-color": "#68ccca" + } + }, + { + "id": "n1", + "position": { + "x": -44.95912785436991, + "y": 327.97628342478527 + }, + "caption": "CertTemplate1-1", + "labels": [], + "properties": { + "schemaversion": "2", + "effectiveekus": "[\"2.5.29.37.0\"]" + }, + "style": { + "node-color": "#fda1ff" + } + }, + { + "id": "n2", + "position": { + "x": 433.6313891898351, + "y": 50 + }, + "caption": "NTAuthStore1", + "labels": [], + "properties": {}, + "style": { + "node-color": "#7b64ff" + } + }, + { + "id": "n3", + "position": { + "x": 433.6313891898351, + "y": 327.97628342478527 + }, + "caption": "EnterpriseCA1", + "labels": [], + "properties": {}, + "style": { + "node-color": "#b0bc00" + } + }, + { + "id": "n4", + "position": { + "x": 194.3361306677326, + "y": 557.3609078601627 + }, + "caption": "CertTemplate1-2", + "labels": [], + "properties": { + "schemaversion": "1", + "effectiveekus": "[\"2.5.29.37.0\"]" + }, + "style": { + "node-color": "#fda1ff" + } + }, + { + "id": "n5", + "position": { + "x": 433.6313891898351, + "y": 796.6561663822652 + }, + "caption": "CertTemplate1-3", + "labels": [], + "properties": { + "schemaversion": "2", + "effectiveekus": "[\"2.5.29.37.0\"]" + }, + "style": { + "node-color": "#fda1ff" + } + }, + { + "id": "n6", + "position": { + "x": 675.9551990267428, + "y": 327.97628342478527 + }, + "caption": "RootCA1", + "labels": [], + "properties": {}, + "style": { + "node-color": "#e27300" + } + }, + { + "id": "n7", + "position": { + "x": 433.6313891898351, + "y": 500.5114025285298 + }, + "caption": "EnterpriseCA2", + "style": { + "node-color": "#b0bc00" + }, + "labels": [], + "properties": {} + } + ], + "relationships": [ + { + "id": "n0", + "fromId": "n2", + "toId": "n0", + "type": "NTAuthStoreFor", + "properties": {}, + "style": {} + }, + { + "id": "n1", + "fromId": "n3", + "toId": "n2", + "type": "TrustedForNTAuth", + "properties": {}, + "style": {} + }, + { + "id": "n2", + "fromId": "n1", + "toId": "n3", + "type": "PublishedTo", + "properties": {}, + "style": {} + }, + { + "id": "n3", + "fromId": "n4", + "toId": "n3", + "type": "PublishedTo", + "properties": {}, + "style": {} + }, + { + "id": "n5", + "fromId": "n6", + "toId": "n0", + "type": "RootCAFor", + "properties": {}, + "style": {} + }, + { + "id": "n6", + "fromId": "n3", + "toId": "n6", + "type": "EnterpriseCAFor", + "properties": {}, + "style": {} + }, + { + "id": "n7", + "fromId": "n1", + "toId": "n4", + "type": "EnrollOnBehalfOf", + "properties": {}, + "style": { + "arrow-color": "#68ccca" + } + }, + { + "id": "n12", + "type": "PublishedTo", + "fromId": "n5", + "toId": "n7", + "style": {}, + "properties": {} + }, + { + "id": "n13", + "type": "EnterpriseCAFor", + "fromId": "n7", + "toId": "n6", + "style": {}, + "properties": {} + }, + { + "id": "n14", + "type": "EnrollOnBehalfOf", + "style": { + "arrow-color": "#68ccca" + }, + "properties": {}, + "fromId": "n4", + "toId": "n4" + } + ] +} \ No newline at end of file diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg new file mode 100644 index 0000000000..d7192c6074 --- /dev/null +++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg @@ -0,0 +1 @@ +NTAuthStoreForTrustedForNTAuthPublishedToPublishedToRootCAForEnterpriseCAForEnrollOnBehalfOfPublishedToEnterpriseCAForEnrollOnBehalfOfDomain1CertTemplate1-1schemaversion:2effectiveekus:["2.5.29.37.0"]NTAuthStore1EnterpriseCA1CertTemplate1-2schemaversion:1effectiveekus:["2.5.29.37.0"]CertTemplate1-3schemaversion:2effectiveekus:["2.5.29.37.0"]RootCA1EnterpriseCA2 \ No newline at end of file diff --git a/packages/go/analysis/ad/adcs.go b/packages/go/analysis/ad/adcs.go index 537b0b0770..ac1f9a76cb 100644 --- a/packages/go/analysis/ad/adcs.go +++ b/packages/go/analysis/ad/adcs.go @@ -47,30 +47,33 @@ func PostADCS(ctx context.Context, db graph.Database, groupExpansions impact.Pat return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed fetching domain nodes: %w", err) } else if step1Stats, err := postADCSPreProcessStep1(ctx, db, enterpriseCertAuthorities, rootCertAuthorities, aiaCertAuthorities, certTemplates); err != nil { return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed adcs pre-processing step 1: %w", err) - } else if step2Stats, err := postADCSPreProcessStep2(ctx, db, certTemplates); err != nil { - return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed adcs pre-processing step 2: %w", err) } else { - operation := analysis.NewPostRelationshipOperation(ctx, db, "ADCS Post Processing") - - operation.Stats.Merge(step1Stats) - operation.Stats.Merge(step2Stats) - var cache = NewADCSCache() cache.BuildCache(ctx, db, enterpriseCertAuthorities, certTemplates, domains) - for _, domain := range domains { - innerDomain := domain + if step2Stats, err := postADCSPreProcessStep2(ctx, db, domains, enterpriseCertAuthorities, certTemplates, cache); err != nil { + return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed adcs pre-processing step 2: %w", err) + } else { + + operation := analysis.NewPostRelationshipOperation(ctx, db, "ADCS Post Processing") + + operation.Stats.Merge(step1Stats) + operation.Stats.Merge(step2Stats) - for _, enterpriseCA := range enterpriseCertAuthorities { - innerEnterpriseCA := enterpriseCA + for _, domain := range domains { + innerDomain := domain - if cache.DoesCAChainProperlyToDomain(innerEnterpriseCA, innerDomain) { - processEnterpriseCAWithValidCertChainToDomain(innerEnterpriseCA, innerDomain, groupExpansions, cache, operation) + for _, enterpriseCA := range enterpriseCertAuthorities { + innerEnterpriseCA := enterpriseCA + + if cache.DoesCAChainProperlyToDomain(innerEnterpriseCA, innerDomain) { + processEnterpriseCAWithValidCertChainToDomain(innerEnterpriseCA, innerDomain, groupExpansions, cache, operation) + } } } - } + return &operation.Stats, operation.Done() - return &operation.Stats, operation.Done() + } } } @@ -97,10 +100,10 @@ func postADCSPreProcessStep1(ctx context.Context, db graph.Database, enterpriseC } // postADCSPreProcessStep2 Processes the edges that are dependent on those processed in postADCSPreProcessStep1 -func postADCSPreProcessStep2(ctx context.Context, db graph.Database, certTemplates []*graph.Node) (*analysis.AtomicPostProcessingStats, error) { +func postADCSPreProcessStep2(ctx context.Context, db graph.Database, domains, enterpriseCertAuthorities, certTemplates []*graph.Node, cache ADCSCache) (*analysis.AtomicPostProcessingStats, error) { operation := analysis.NewPostRelationshipOperation(ctx, db, "ADCS Post Processing Step 2") - if err := PostEnrollOnBehalfOf(certTemplates, operation); err != nil { + if err := PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates, cache, operation); err != nil { operation.Done() return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed post processing for %s: %w", ad.EnrollOnBehalfOf.String(), err) } else { diff --git a/packages/go/analysis/ad/esc3.go b/packages/go/analysis/ad/esc3.go index 2406e0081b..28abe1fcec 100644 --- a/packages/go/analysis/ad/esc3.go +++ b/packages/go/analysis/ad/esc3.go @@ -142,60 +142,71 @@ func PostADCSESC3(ctx context.Context, tx graph.Transaction, outC chan<- analysi return nil } -func PostEnrollOnBehalfOf(certTemplates []*graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) error { +func PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates []*graph.Node, cache ADCSCache, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) error { versionOneTemplates := make([]*graph.Node, 0) versionTwoTemplates := make([]*graph.Node, 0) - for _, node := range certTemplates { if version, err := node.Properties.Get(ad.SchemaVersion.String()).Float64(); errors.Is(err, graph.ErrPropertyNotFound) { log.Warnf("Did not get schema version for cert template %d: %v", node.ID, err) } else if err != nil { log.Errorf("Error getting schema version for cert template %d: %v", node.ID, err) + } else if version == 1 { + versionOneTemplates = append(versionOneTemplates, node) + } else if version >= 2 { + versionTwoTemplates = append(versionTwoTemplates, node) } else { - if version == 1 { - versionOneTemplates = append(versionOneTemplates, node) - } else if version >= 2 { - versionTwoTemplates = append(versionTwoTemplates, node) - } else { - log.Warnf("Got cert template %d with an invalid version %d", node.ID, version) - } + log.Warnf("Got cert template %d with an invalid version %d", node.ID, version) } } - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { - if results, err := EnrollOnBehalfOfVersionTwo(tx, versionTwoTemplates, certTemplates); err != nil { - return err - } else { - for _, result := range results { - if !channels.Submit(ctx, outC, result) { - return nil - } - } + for _, domain := range domains { + innerDomain := domain - return nil - } - }) + for _, enterpriseCA := range enterpriseCertAuthorities { + innerEnterpriseCA := enterpriseCA - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { - if results, err := EnrollOnBehalfOfVersionOne(tx, versionOneTemplates, certTemplates); err != nil { - return err - } else { - for _, result := range results { - if !channels.Submit(ctx, outC, result) { + if cache.DoesCAChainProperlyToDomain(innerEnterpriseCA, innerDomain) { + if publishedCertTemplates := cache.GetPublishedTemplateCache(enterpriseCA.ID); len(publishedCertTemplates) == 0 { return nil + } else { + operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + if results, err := EnrollOnBehalfOfVersionTwo(tx, versionTwoTemplates, publishedCertTemplates, innerDomain); err != nil { + return err + } else { + for _, result := range results { + if !channels.Submit(ctx, outC, result) { + return nil + } + } + + return nil + } + }) + + operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + if results, err := EnrollOnBehalfOfVersionOne(tx, versionOneTemplates, publishedCertTemplates, innerDomain); err != nil { + return err + } else { + for _, result := range results { + if !channels.Submit(ctx, outC, result) { + return nil + } + } + + return nil + } + }) } } - - return nil } - }) + } return nil } -func EnrollOnBehalfOfVersionTwo(tx graph.Transaction, versionTwoCertTemplates, allCertTemplates []*graph.Node) ([]analysis.CreatePostRelationshipJob, error) { +func EnrollOnBehalfOfVersionTwo(tx graph.Transaction, versionTwoCertTemplates, publishedTemplates []*graph.Node, domainNode *graph.Node) ([]analysis.CreatePostRelationshipJob, error) { results := make([]analysis.CreatePostRelationshipJob, 0) - for _, certTemplateOne := range allCertTemplates { + for _, certTemplateOne := range publishedTemplates { if hasBadEku, err := certTemplateHasEku(certTemplateOne, EkuAnyPurpose); errors.Is(err, graph.ErrPropertyNotFound) { log.Warnf("Did not get EffectiveEKUs for cert template %d: %v", certTemplateOne.ID, err) } else if err != nil { @@ -208,12 +219,6 @@ func EnrollOnBehalfOfVersionTwo(tx graph.Transaction, versionTwoCertTemplates, a log.Errorf("Error getting EffectiveEKUs for cert template %d: %v", certTemplateOne.ID, err) } else if !hasEku { continue - } else if domainNode, err := getDomainForCertTemplate(tx, certTemplateOne); err != nil { - log.Errorf("Error getting domain node for cert template %d: %v", certTemplateOne.ID, err) - } else if isLinked, err := DoesCertTemplateLinkToDomain(tx, certTemplateOne, domainNode); err != nil { - log.Errorf("Error fetching paths from cert template %d to domain: %v", certTemplateOne.ID, err) - } else if !isLinked { - continue } else { for _, certTemplateTwo := range versionTwoCertTemplates { if certTemplateOne.ID == certTemplateTwo.ID { @@ -260,10 +265,10 @@ func certTemplateHasEku(certTemplate *graph.Node, targetEkus ...string) (bool, e } } -func EnrollOnBehalfOfVersionOne(tx graph.Transaction, versionOneCertTemplates []*graph.Node, allCertTemplates []*graph.Node) ([]analysis.CreatePostRelationshipJob, error) { +func EnrollOnBehalfOfVersionOne(tx graph.Transaction, versionOneCertTemplates []*graph.Node, publishedTemplates []*graph.Node, domainNode *graph.Node) ([]analysis.CreatePostRelationshipJob, error) { results := make([]analysis.CreatePostRelationshipJob, 0) - for _, certTemplateOne := range allCertTemplates { + for _, certTemplateOne := range publishedTemplates { //prefilter as much as we can first if hasEku, err := certTemplateHasEkuOrAll(certTemplateOne, EkuCertRequestAgent, EkuAnyPurpose); errors.Is(err, graph.ErrPropertyNotFound) { log.Warnf("Error checking ekus for certtemplate %d: %v", certTemplateOne.ID, err) @@ -271,12 +276,6 @@ func EnrollOnBehalfOfVersionOne(tx graph.Transaction, versionOneCertTemplates [] log.Errorf("Error checking ekus for certtemplate %d: %v", certTemplateOne.ID, err) } else if !hasEku { continue - } else if domainNode, err := getDomainForCertTemplate(tx, certTemplateOne); err != nil { - log.Errorf("Error getting domain node for certtemplate %d: %v", certTemplateOne.ID, err) - } else if hasPath, err := DoesCertTemplateLinkToDomain(tx, certTemplateOne, domainNode); err != nil { - log.Errorf("Error fetching paths from certtemplate %d to domain: %v", certTemplateOne.ID, err) - } else if !hasPath { - continue } else { for _, certTemplateTwo := range versionOneCertTemplates { if hasPath, err := DoesCertTemplateLinkToDomain(tx, certTemplateTwo, domainNode); err != nil { @@ -359,19 +358,9 @@ func certTemplateHasEkuOrAll(certTemplate *graph.Node, targetEkus ...string) (bo } } -func getDomainForCertTemplate(tx graph.Transaction, certTemplate *graph.Node) (*graph.Node, error) { - if domainSid, err := certTemplate.Properties.Get(ad.DomainSID.String()).String(); err != nil { - return &graph.Node{}, err - } else if domainNode, err := analysis.FetchNodeByObjectID(tx, domainSid); err != nil { - return &graph.Node{}, err - } else { - return domainNode, nil - } -} - func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *graph.Relationship) (graph.PathSet, error) { /* - MATCH p1 = (x)-[:MemberOf*0..]->()-[:GenericAll|Enroll|AllExtendedRights]->(ct1:CertTemplate)-[:PublishedTo]->(eca1:EnterpriseCA) + MATCH p1 = (x)-[:MemberOf*0..]->()-[:GenericAll|Enroll|AllExtendedRights]->(ct1:CertTemplate)-[:PublishedTo]->(eca1:EnterpriseCA)-[:TrustedForNTAuth]->(:NTAuthStore)-[:NTAuthStoreFor]->(d) WHERE x.objectid = "S-1-5-21-83094068-830424655-2031507174-500" AND d.objectid = "S-1-5-21-83094068-830424655-2031507174" AND ct1.requiresmanagerapproval = false @@ -483,7 +472,7 @@ func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *gr for _, n := range startNodes.Slice() { if err := traversalInst.BreadthFirst(ctx, traversal.Plan{ Root: n, - Driver: ADCSESC3Path1Pattern(enterpriseCANodes).Do(func(terminal *graph.PathSegment) error { + Driver: ADCSESC3Path1Pattern(edge.EndID, enterpriseCANodes).Do(func(terminal *graph.PathSegment) error { certTemplateNode := terminal.Search(func(nextSegment *graph.PathSegment) bool { return nextSegment.Node.Kinds.ContainsOneOf(ad.CertTemplate) }) @@ -673,7 +662,7 @@ func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *gr return paths, nil } -func ADCSESC3Path1Pattern(enterpriseCAs cardinality.Duplex[uint64]) traversal.PatternContinuation { +func ADCSESC3Path1Pattern(domainId graph.ID, enterpriseCAs cardinality.Duplex[uint64]) traversal.PatternContinuation { return traversal.NewPattern().OutboundWithDepth(0, 0, query.And( query.Kind(query.Relationship(), ad.MemberOf), query.Kind(query.End(), ad.Group), @@ -696,6 +685,14 @@ func ADCSESC3Path1Pattern(enterpriseCAs cardinality.Duplex[uint64]) traversal.Pa query.KindIn(query.Relationship(), ad.PublishedTo), query.InIDs(query.End(), graph.DuplexToGraphIDs(enterpriseCAs)...), query.Kind(query.End(), ad.EnterpriseCA), + )). + Outbound(query.And( + query.KindIn(query.Relationship(), ad.TrustedForNTAuth), + query.Kind(query.End(), ad.NTAuthStore), + )). + Outbound(query.And( + query.KindIn(query.Relationship(), ad.NTAuthStoreFor), + query.Equals(query.EndID(), domainId), )) } diff --git a/packages/go/cypher/models/cypher/functions.go b/packages/go/cypher/models/cypher/functions.go index 1ffed87f99..c6dce37f2b 100644 --- a/packages/go/cypher/models/cypher/functions.go +++ b/packages/go/cypher/models/cypher/functions.go @@ -30,4 +30,21 @@ const ( NodeLabelsFunction = "labels" EdgeTypeFunction = "type" StringSplitToArrayFunction = "split" + ToStringFunction = "tostring" + ToIntegerFunction = "toint" + ListSizeFunction = "size" + + // ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/) + ITTCYear = "year" + ITTCMonth = "month" + ITTCDay = "day" + ITTCHour = "hour" + ITTCMinute = "minute" + ITTCSecond = "second" + ITTCMillisecond = "millisecond" + ITTCMicrosecond = "microsecond" + ITTCNanosecond = "nanosecond" + ITTCTimeZone = "timezone" + ITTCEpochSeconds = "epochseconds" + ITTCEpochMilliseconds = "epochmillis" ) diff --git a/packages/go/cypher/models/pgsql/format/format.go b/packages/go/cypher/models/pgsql/format/format.go index ac2caeaf21..4280aef044 100644 --- a/packages/go/cypher/models/pgsql/format/format.go +++ b/packages/go/cypher/models/pgsql/format/format.go @@ -137,6 +137,12 @@ func formatValue(builder *OutputBuilder, value any) error { case bool: builder.Write(strconv.FormatBool(typedValue)) + case float32: + builder.Write(strconv.FormatFloat(float64(typedValue), 'f', -1, 64)) + + case float64: + builder.Write(strconv.FormatFloat(typedValue, 'f', -1, 64)) + default: return fmt.Errorf("unsupported literal type: %T", value) } @@ -482,8 +488,27 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error { exprStack = append(exprStack, pgsql.FormattingLiteral("not ")) } + case pgsql.ProjectionFrom: + for idx, projection := range typedNextExpr.Projection { + if idx > 0 { + builder.Write(", ") + } + + if err := formatNode(builder, projection); err != nil { + return err + } + } + + if len(typedNextExpr.From) > 0 { + builder.Write(" from ") + + if err := formatFromClauses(builder, typedNextExpr.From); err != nil { + return err + } + } + default: - return fmt.Errorf("unsupported node type: %T", nextExpr) + return fmt.Errorf("unable to format pgsql node type: %T", nextExpr) } } diff --git a/packages/go/cypher/models/pgsql/functions.go b/packages/go/cypher/models/pgsql/functions.go index 92bb29f247..be5267764a 100644 --- a/packages/go/cypher/models/pgsql/functions.go +++ b/packages/go/cypher/models/pgsql/functions.go @@ -24,6 +24,7 @@ const ( FunctionJSONBToTextArray Identifier = "jsonb_to_text_array" FunctionJSONBArrayElementsText Identifier = "jsonb_array_elements_text" FunctionJSONBBuildObject Identifier = "jsonb_build_object" + FunctionJSONBArrayLength Identifier = "jsonb_array_length" FunctionArrayLength Identifier = "array_length" FunctionArrayAggregate Identifier = "array_agg" FunctionMin Identifier = "min" @@ -41,4 +42,5 @@ const ( FunctionCount Identifier = "count" FunctionStringToArray Identifier = "string_to_array" FunctionEdgesToPath Identifier = "edges_to_path" + FunctionExtract Identifier = "extract" ) diff --git a/packages/go/cypher/models/pgsql/identifiers.go b/packages/go/cypher/models/pgsql/identifiers.go index 4e19ab092a..582290bd93 100644 --- a/packages/go/cypher/models/pgsql/identifiers.go +++ b/packages/go/cypher/models/pgsql/identifiers.go @@ -25,8 +25,23 @@ import ( const ( WildcardIdentifier Identifier = "*" + EpochIdentifier Identifier = "epoch" ) +var reservedIdentifiers = []Identifier{ + EpochIdentifier, +} + +func IsReservedIdentifier(identifier Identifier) bool { + for _, reservedIdentifier := range reservedIdentifiers { + if identifier == reservedIdentifier { + return true + } + } + + return false +} + func AsOptionalIdentifier(identifier Identifier) models.Optional[Identifier] { return models.ValueOptional(identifier) } diff --git a/packages/go/cypher/models/pgsql/model.go b/packages/go/cypher/models/pgsql/model.go index acc97590b4..c890a855e0 100644 --- a/packages/go/cypher/models/pgsql/model.go +++ b/packages/go/cypher/models/pgsql/model.go @@ -403,9 +403,17 @@ type AnyExpression struct { } func NewAnyExpression(inner Expression) AnyExpression { - return AnyExpression{ + newAnyExpression := AnyExpression{ Expression: inner, } + + // This is a guard to prevent recursive wrapping of an expression in an Any expression + switch innerTypeHint := inner.(type) { + case TypeHinted: + newAnyExpression.CastType = innerTypeHint.TypeHint() + } + + return newAnyExpression } func (s AnyExpression) AsExpression() Expression { @@ -972,6 +980,19 @@ func (s Projection) NodeType() string { return "projection" } +type ProjectionFrom struct { + Projection Projection + From []FromClause +} + +func (s ProjectionFrom) NodeType() string { + return "projection from" +} + +func (s ProjectionFrom) AsExpression() Expression { + return s +} + // Select is a SQL expression that is evaluated to fetch data. type Select struct { Distinct bool diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go index de4a12ba6f..8839e37971 100644 --- a/packages/go/cypher/models/pgsql/pgtypes.go +++ b/packages/go/cypher/models/pgsql/pgtypes.go @@ -93,6 +93,8 @@ const ( TextArray DataType = "text[]" JSONB DataType = "jsonb" JSONBArray DataType = "jsonb[]" + Numeric DataType = "numeric" + NumericArray DataType = "numeric[]" Date DataType = "date" TimeWithTimeZone DataType = "time with time zone" TimeWithoutTimeZone DataType = "time without time zone" @@ -109,7 +111,8 @@ const ( ExpansionTerminalNode DataType = "expansion_terminal_node" ) -func (s DataType) Convert(other DataType) (DataType, bool) { +// TODO: operator, while unused, is part of a refactor for this function to make it operator aware +func (s DataType) Compatible(other DataType, operator Operator) (DataType, bool) { if s == other { return s, true } @@ -132,6 +135,12 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Float8: return Float8, true + case Float4Array: + return Float4, true + + case Float8Array: + return Float8, true + case Text: return Text, true } @@ -141,6 +150,21 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Float4: return Float8, true + case Float4Array, Float8Array: + return Float8, true + + case Text: + return Text, true + } + + case Numeric: + switch other { + case Float4, Float8, Int2, Int4, Int8: + return Numeric, true + + case Float4Array, Float8Array, NumericArray: + return Numeric, true + case Text: return Text, true } @@ -156,6 +180,15 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Int8: return Int8, true + case Int2Array: + return Int2, true + + case Int4Array: + return Int4, true + + case Int8Array: + return Int8, true + case Text: return Text, true } @@ -168,6 +201,12 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Int8: return Int8, true + case Int2Array, Int4Array: + return Int4, true + + case Int8Array: + return Int8, true + case Text: return Text, true } @@ -177,9 +216,42 @@ func (s DataType) Convert(other DataType) (DataType, bool) { case Int2, Int4, Int8: return Int8, true + case Int2Array, Int4Array, Int8Array: + return Int8, true + + case Text: + return Text, true + } + + case Int: + switch other { + case Int2, Int4, Int: + return Int, true + + case Int8: + return Int8, true + case Text: return Text, true } + + case Int2Array: + switch other { + case Int2Array, Int4Array, Int8Array: + return other, true + } + + case Int4Array: + switch other { + case Int4Array, Int8Array: + return other, true + } + + case Float4Array: + switch other { + case Float4Array, Float8Array: + return other, true + } } return UnsetDataType, false @@ -207,7 +279,7 @@ func (s DataType) MatchesOneOf(others ...DataType) bool { func (s DataType) IsArrayType() bool { switch s { - case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray: + case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray, JSONBArray, NodeCompositeArray, EdgeCompositeArray, NumericArray: return true } @@ -239,6 +311,8 @@ func (s DataType) ToArrayType() (DataType, error) { return Float8Array, nil case Text, TextArray: return TextArray, nil + case Numeric, NumericArray: + return NumericArray, nil default: return UnknownDataType, ErrNoAvailableArrayDataType } @@ -258,8 +332,10 @@ func (s DataType) ArrayBaseType() (DataType, error) { return Float8, nil case TextArray: return Text, nil + case NumericArray: + return Numeric, nil default: - return UnknownDataType, ErrNonArrayDataType + return s, nil } } diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql index 8c18b2ed0b..69b09a8b98 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql @@ -82,7 +82,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 - where (s0.n0).id = any (jsonb_to_text_array(n1.properties -> 'captured_ids')::int4[])) + where (s0.n0).id = any (jsonb_to_text_array(n1.properties -> 'captured_ids')::int8[])) select s1.n0 as s, s1.n1 as e from s1; @@ -498,3 +498,66 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 where n0.properties ->> 'system_tags' like '%' || ('text')::text) select s0.n0 as n from s0; + +-- case: match (n:NodeKind1) where toString(n.functionallevel) in ['2008 R2','2012','2008','2003','2003 Interim','2000 Mixed/Native'] return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'functionallevel')::text = any + (array ['2008 R2', '2012', '2008', '2003', '2003 Interim', '2000 Mixed/Native']::text[])) +select s0.n0 as n +from s0; + +-- case: match (n:NodeKind1) where toInt(n.value) in [1, 2, 3, 4] return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'value')::int8 = any (array [1, 2, 3, 4]::int8[])) +select s0.n0 as n +from s0; + +-- case: match (u:NodeKind1) where u.pwdlastset < (datetime().epochseconds - (365 * 86400)) and not u.pwdlastset IN [-1.0, 0.0] return u limit 100 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'pwdlastset')::numeric < + (extract(epoch from now()::timestamp with time zone)::numeric - (365 * 86400)) + and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) +select s0.n0 as u +from s0 +limit 100; + +-- case: match (u:NodeKind1) where u.pwdlastset < (datetime().epochmillis - (365 * 86400000)) and not u.pwdlastset IN [-1.0, 0.0] return u limit 100 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties -> 'pwdlastset')::numeric < + (extract(epoch from now()::timestamp with time zone)::numeric * 1000 - (365 * 86400000)) + and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) +select s0.n0 as u +from s0 +limit 100; + +-- case: match (n:NodeKind1) where size(n.array_value) > 0 return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and jsonb_array_length(n0.properties -> 'array_value')::int > 0) +select s0.n0 as n +from s0; + +-- case: match (n) where 1 in n.array return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where 1 = any (jsonb_to_text_array(n0.properties -> 'array')::int8[])) +select s0.n0 as n +from s0; + +-- case: match (n) where $p in n.array or $f in n.array return n +-- cypher_params: {"p": 1, "f": "text"} +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where @pi0::float8 = any (jsonb_to_text_array(n0.properties -> 'array')::float8[]) + or @pi1::text = any (jsonb_to_text_array(n0.properties -> 'array')::text[])) +select s0.n0 as n +from s0; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql index 2c85c05023..4cb3465067 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql @@ -21,7 +21,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id) -select edges_to_path(variadic array [(s0.e0).id]::int4[])::pathcomposite as p +select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p from s0; -- case: match p = ()-[r1]->()-[r2]->(e) return e @@ -92,7 +92,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite edge e1 join node n2 on (n2.properties -> 'is_target')::bool and n2.id = e1.start_id where (s0.n1).id = e1.end_id) -select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int4[])::pathcomposite as p +select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int8[])::pathcomposite as p from s1; -- case: match p = ()-[*..]->() return p limit 1 @@ -126,7 +126,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id) select edges_to_path(variadic ep0)::pathcomposite as p from s0 limit 1; @@ -162,7 +162,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied), s1 as (select s0.e0 as e0, s0.ep0 as ep0, @@ -174,7 +174,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat edge e1 join node n2 on n2.id = e1.end_id where (s0.n1).id = e1.start_id) -select edges_to_path(variadic array [(s1.e1).id]::int4[] || s1.ep0)::pathcomposite as p +select edges_to_path(variadic array [(s1.e1).id]::int8[] || s1.ep0)::pathcomposite as p from s1 limit 1; @@ -220,8 +220,8 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite ex0 join edge e1 on e1.id = any (ex0.path) join node n1 on n1.id = ex0.root_id - join node n2 on e1.id = ex0.path[array_length(ex0.path, 1)::int4] and n2.id = e1.end_id) -select s1.e0 as e, edges_to_path(variadic array [(s1.e0).id]::int4[] || s1.ep0)::pathcomposite as p + join node n2 on e1.id = ex0.path[array_length(ex0.path, 1)::int] and n2.id = e1.end_id) +select s1.e0 as e, edges_to_path(variadic array [(s1.e0).id]::int8[] || s1.ep0)::pathcomposite as p from s1; -- case: match p = (m:NodeKind1)-[:EdgeKind1]->(c:NodeKind2) where m.objectid ends with "-513" and not toUpper(c.operatingsystem) contains "SERVER" return p limit 1000 @@ -235,6 +235,6 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite not upper(n1.properties ->> 'operatingsystem')::text like '%SERVER%' and n1.id = e0.end_id where e0.kind_id = any (array [11]::int2[])) -select edges_to_path(variadic array [(s0.e0).id]::int4[])::pathcomposite as p +select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p from s0 limit 1000; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql index e49a4e0e84..87da96d391 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql @@ -45,7 +45,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id) select s0.n0 as n, s0.n1 as e from s0; @@ -80,7 +80,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select s0.n1 as e from s0; @@ -119,7 +119,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select s0.n1 as e from s0; @@ -155,7 +155,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select s0.n0 as n from s0; @@ -191,7 +191,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied), s1 as (select s0.e0 as e0, s0.ep0 as ep0, @@ -237,7 +237,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied), s1 as (select s0.e0 as e0, s0.ep0 as ep0, @@ -286,7 +286,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat ex1 join edge e2 on e2.id = any (ex1.path) join node n2 on n2.id = ex1.root_id - join node n3 on e2.id = ex1.path[array_length(ex1.path, 1)::int4] and n3.id = e2.end_id) + join node n3 on e2.id = ex1.path[array_length(ex1.path, 1)::int] and n3.id = e2.end_id) select s2.n3 as l from s2; @@ -332,7 +332,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0 @@ -372,7 +372,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied and n0.id <> n1.id) select edges_to_path(variadic ep0)::pathcomposite as p @@ -422,7 +422,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0; @@ -463,7 +463,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.start_id) select edges_to_path(variadic ep0)::pathcomposite as p from s0 limit 10; @@ -505,7 +505,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.start_id where ex0.satisfied), s1 as (with recursive ex1(root_id, next_id, depth, satisfied, is_cycle, path) as (select e1.start_id, e1.end_id, @@ -544,7 +544,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat ex1 join edge e1 on e1.id = any (ex1.path) join node n1 on n1.id = ex1.root_id - join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int4] and n2.id = e1.start_id) + join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int] and n2.id = e1.start_id) select edges_to_path(variadic s1.ep1 || s1.ep0)::pathcomposite as p from s1 limit 10; @@ -593,7 +593,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0 diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql index 5544c30586..b7035887f1 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql @@ -28,7 +28,7 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id) + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id) select edges_to_path(variadic ep0)::pathcomposite as p from s0; @@ -46,7 +46,28 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) from ex0 join edge e0 on e0.id = any (ex0.path) join node n0 on n0.id = ex0.root_id - join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0; + +-- case: match p=shortestPath((n:NodeKind1)-[:EdgeKind1*1..]->(m)) where 'admin_tier_0' in split(m.system_tags, ' ') and n.objectid ends with '-513' and n<>m return p limit 1000 +-- cypher_params: {} +-- pgsql_params: {"pi0":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select e0.start_id, e0.end_id, 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.start_id = e0.end_id, array [e0.id] from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.properties ->> 'objectid' like '%-513' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [11]::int2[]);", "pi1":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select ex0.root_id, e0.end_id, ex0.depth + 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.id = any (ex0.path), ex0.path || e0.id from pathspace ex0 join edge e0 on e0.start_id = ex0.next_id join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle and e0.kind_id = any (array [11]::int2[]);"} +with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path) + as (select * from asp_harness(@pi0::text, @pi1::text, 5)) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id + where ex0.satisfied + and n0.id <> n1.id) +select edges_to_path(variadic ep0)::pathcomposite as p +from s0 +limit 1000; diff --git a/packages/go/cypher/models/pgsql/translate/expansion.go b/packages/go/cypher/models/pgsql/translate/expansion.go index 872b66c2bb..e819d7cb07 100644 --- a/packages/go/cypher/models/pgsql/translate/expansion.go +++ b/packages/go/cypher/models/pgsql/translate/expansion.go @@ -320,7 +320,7 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath}, pgsql.NewLiteral(1, pgsql.Int8), }, - CastType: pgsql.Int4, + CastType: pgsql.Int, }, }, }, @@ -417,8 +417,15 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave ), } - // Make sure to only accept paths that are satisfied - expansion.ProjectionStatement.Where = pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionSatisfied} + // Constraints that target the terminal node may crop up here where it's finally in scope. Additionally, + // only accept paths that are marked satisfied from the recursive descent CTE + if constraints, err := consumeConstraintsFrom(traversalStep.Expansion.Value.Frame.Visible, s.treeTranslator.IdentifierConstraints); err != nil { + return pgsql.Query{}, err + } else if projectionConstraints, err := ConjoinExpressions([]pgsql.Expression{pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionSatisfied}, constraints.Expression}); err != nil { + return pgsql.Query{}, err + } else { + expansion.ProjectionStatement.Where = projectionConstraints + } } } else { expansion.PrimerStatement.Projection = []pgsql.SelectItem{ @@ -653,7 +660,7 @@ func (s *Translator) buildExpansionPatternRoot(part *PatternPart, traversalStep pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath}, pgsql.NewLiteral(1, pgsql.Int8), }, - CastType: pgsql.Int4, + CastType: pgsql.Int, }, }, }, @@ -940,7 +947,7 @@ func (s *Translator) buildExpansionPatternStep(part *PatternPart, traversalStep pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath}, pgsql.NewLiteral(1, pgsql.Int8), }, - CastType: pgsql.Int4, + CastType: pgsql.Int, }, }, }, diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go index 1da57810e6..1fdd24a364 100644 --- a/packages/go/cypher/models/pgsql/translate/expression.go +++ b/packages/go/cypher/models/pgsql/translate/expression.go @@ -31,8 +31,15 @@ type PropertyLookup struct { } func asPropertyLookup(expression pgsql.Expression) (*pgsql.BinaryExpression, bool) { - if binaryExpression, isBinaryExpression := expression.(*pgsql.BinaryExpression); isBinaryExpression { - return binaryExpression, pgsql.OperatorIsPropertyLookup(binaryExpression.Operator) + switch typedExpression := expression.(type) { + case pgsql.AnyExpression: + // This is here to unwrap Any expressions that have been passed in as a property lookup. This is + // common when dealing with array operators. In the future this check should be handled by the + // caller to simplify the logic here. + return asPropertyLookup(typedExpression.Expression) + + case *pgsql.BinaryExpression: + return typedExpression, pgsql.OperatorIsPropertyLookup(typedExpression.Operator) } return nil, false @@ -64,10 +71,17 @@ func ExtractSyntaxNodeReferences(root pgsql.SyntaxNode) (*pgsql.IdentifierSet, e func(node pgsql.SyntaxNode, errorHandler walk.CancelableErrorHandler) { switch typedNode := node.(type) { case pgsql.Identifier: - dependencies.Add(typedNode) + // Filter for reserved identifiers + if !pgsql.IsReservedIdentifier(typedNode) { + dependencies.Add(typedNode) + } case pgsql.CompoundIdentifier: - dependencies.Add(typedNode.Root()) + identifier := typedNode.Root() + + if !pgsql.IsReservedIdentifier(identifier) { + dependencies.Add(identifier) + } } }, )) @@ -83,7 +97,6 @@ func applyUnaryExpressionTypeHints(expression *pgsql.UnaryExpression) error { func rewritePropertyLookupOperator(propertyLookup *pgsql.BinaryExpression, dataType pgsql.DataType) pgsql.Expression { if dataType.IsArrayType() { - // This property lookup needs to be coerced into an array type using a function return pgsql.FunctionCall{ Function: pgsql.FunctionJSONBToTextArray, Parameters: []pgsql.Expression{propertyLookup}, @@ -126,7 +139,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy if isLeftHinted { if isRightHinted { - if higherLevelHint, matchesOrConverts := leftHint.Convert(rightHint); !matchesOrConverts { + if higherLevelHint, matchesOrConverts := leftHint.Compatible(rightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, rightHint) } else { return higherLevelHint, nil @@ -136,44 +149,52 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy } else if inferredRightHint == pgsql.UnknownDataType { // Assume the right side is convertable and return the left operand hint return leftHint, nil - } else if upcastHint, matchesOrConverts := leftHint.Convert(inferredRightHint); !matchesOrConverts { + } else if upcastHint, matchesOrConverts := leftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredRightHint) } else { return upcastHint, nil } } else if isRightHinted { // There's no left type, attempt to infer it - if inferredLeftHint, err := InferExpressionType(expression.ROperand); err != nil { + if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil { return pgsql.UnsetDataType, err } else if inferredLeftHint == pgsql.UnknownDataType { // Assume the right side is convertable and return the left operand hint return rightHint, nil - } else if upcastHint, matchesOrConverts := rightHint.Convert(inferredLeftHint); !matchesOrConverts { - return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredLeftHint) + } else if upcastHint, matchesOrConverts := rightHint.Compatible(inferredLeftHint, expression.Operator); !matchesOrConverts { + return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, rightHint, inferredLeftHint) } else { return upcastHint, nil } - } else if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil { - return pgsql.UnsetDataType, err - } else if inferredRightHint, err := InferExpressionType(expression.ROperand); err != nil { - return pgsql.UnsetDataType, err - } else if inferredLeftHint == pgsql.UnknownDataType && inferredRightHint == pgsql.UnknownDataType { - // If neither side has type information then check the operator to see if it implies some type hinting + } else { + // If neither side has specific type information then check the operator to see if it implies some type + // hinting before resorting to inference switch expression.Operator { case pgsql.OperatorStartsWith, pgsql.OperatorContains, pgsql.OperatorEndsWith: // String operations imply the operands must be text return pgsql.Text, nil - // TODO: Boolean inference for OperatorAnd and OperatorOr may want to be plumbed here + case pgsql.OperatorAnd, pgsql.OperatorOr: + // Boolean operators that the operands must be boolean + return pgsql.Boolean, nil default: - // Unable to infer any type information - return pgsql.UnknownDataType, nil + // The operator does not imply specific type information onto the operands. Attempt to infer any + // information as a last ditch effort to type the AST nodes + if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil { + return pgsql.UnsetDataType, err + } else if inferredRightHint, err := InferExpressionType(expression.ROperand); err != nil { + return pgsql.UnsetDataType, err + } else if inferredLeftHint == pgsql.UnknownDataType && inferredRightHint == pgsql.UnknownDataType { + // Unable to infer any type information, this may be resolved elsewhere so this is not explicitly + // an error condition + return pgsql.UnknownDataType, nil + } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts { + return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, inferredLeftHint, inferredRightHint) + } else { + return higherLevelHint, nil + } } - } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Convert(inferredRightHint); !matchesOrConverts { - return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredLeftHint) - } else { - return higherLevelHint, nil } } @@ -191,7 +212,7 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { // Infer type information for well known column names switch typedExpression[1] { case pgsql.ColumnGraphID, pgsql.ColumnID, pgsql.ColumnStartID, pgsql.ColumnEndID: - return pgsql.Int4, nil + return pgsql.Int8, nil case pgsql.ColumnKindID: return pgsql.Int2, nil @@ -215,13 +236,18 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { // This is unknown, not unset meaning that it can be re-cast by future inference inspections return pgsql.UnknownDataType, nil - case pgsql.OperatorAnd, pgsql.OperatorOr: + case pgsql.OperatorAnd, pgsql.OperatorOr, pgsql.OperatorEquals, pgsql.OperatorGreaterThan, pgsql.OperatorGreaterThanOrEqualTo, + pgsql.OperatorLessThan, pgsql.OperatorLessThanOrEqualTo, pgsql.OperatorIn, pgsql.OperatorJSONBFieldExists, + pgsql.OperatorLike, pgsql.OperatorILike, pgsql.OperatorPGArrayOverlap: return pgsql.Boolean, nil default: return inferBinaryExpressionType(typedExpression) } + case pgsql.Parenthetical: + return InferExpressionType(typedExpression.Expression) + default: log.Infof("unable to infer type hint for expression type: %T", expression) return pgsql.UnknownDataType, nil @@ -263,31 +289,65 @@ func TypeCastExpression(expression pgsql.Expression, dataType pgsql.DataType) (p return pgsql.NewTypeCast(expression, dataType), nil } -func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression, expressionTypeHint pgsql.DataType) error { - if leftPropertyLookup, isPropertyLookup := asPropertyLookup(expression.LOperand); isPropertyLookup { - if lookupRequiresElementType(expressionTypeHint, expression.Operator, expression.ROperand) { - // Take the base type of the array type hint: in - if arrayBaseType, err := expressionTypeHint.ArrayBaseType(); err != nil { - return err - } else { - expressionTypeHint = arrayBaseType - } - } +func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { + var ( + leftPropertyLookup, hasLeftPropertyLookup = asPropertyLookup(expression.LOperand) + rightPropertyLookup, hasRightPropertyLookup = asPropertyLookup(expression.ROperand) + ) - expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, expressionTypeHint) + // Don't rewrite direct property comparisons + if hasLeftPropertyLookup && hasRightPropertyLookup { + return nil } - if rightPropertyLookup, isPropertyLookup := asPropertyLookup(expression.ROperand); isPropertyLookup { - if lookupRequiresElementType(expressionTypeHint, expression.Operator, expression.LOperand) { - // Take the base type of the array type hint: in - if arrayBaseType, err := expressionTypeHint.ArrayBaseType(); err != nil { + if hasLeftPropertyLookup { + // This check exists here to prevent from overwriting a property lookup that's part of a in + // binary expression. This may want for better ergonomics in the future + if anyExpression, isAnyExpression := expression.ROperand.(pgsql.AnyExpression); isAnyExpression { + if arrayBaseType, err := anyExpression.CastType.ArrayBaseType(); err != nil { return err } else { - expressionTypeHint = arrayBaseType + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) + } + } else if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { + return err + } else { + switch expression.Operator { + case pgsql.OperatorIn: + if arrayBaseType, err := rOperandTypeHint.ArrayBaseType(); err != nil { + return err + } else { + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) + } + + case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch: + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, pgsql.Text) + + default: + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, rOperandTypeHint) } } + } + + if hasRightPropertyLookup { + if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil { + return err + } else { + switch expression.Operator { + case pgsql.OperatorIn: + if arrayType, err := lOperandTypeHint.ToArrayType(); err != nil { + return err + } else { + expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, arrayType) + } + + case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch: + expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, pgsql.Text) - expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, expressionTypeHint) + default: + expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, lOperandTypeHint) + } + } } return nil @@ -301,11 +361,7 @@ func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error { return nil } - if expressionTypeHint, err := InferExpressionType(expression); err != nil { - return err - } else { - return rewritePropertyLookupOperands(expression, expressionTypeHint) - } + return rewritePropertyLookupOperands(expression) } type Builder struct { diff --git a/packages/go/cypher/models/pgsql/translate/expression_test.go b/packages/go/cypher/models/pgsql/translate/expression_test.go index 802c9eab75..1e514ce38e 100644 --- a/packages/go/cypher/models/pgsql/translate/expression_test.go +++ b/packages/go/cypher/models/pgsql/translate/expression_test.go @@ -72,7 +72,7 @@ func TestInferExpressionType(t *testing.T) { ), ), }, { - ExpectedType: pgsql.Text, + ExpectedType: pgsql.Boolean, Expression: pgsql.NewBinaryExpression( mustAsLiteral("123"), pgsql.OperatorIn, diff --git a/packages/go/cypher/models/pgsql/translate/projection.go b/packages/go/cypher/models/pgsql/translate/projection.go index 6bf21f4920..5b3259dfbd 100644 --- a/packages/go/cypher/models/pgsql/translate/projection.go +++ b/packages/go/cypher/models/pgsql/translate/projection.go @@ -197,7 +197,7 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope * pgsql.OperatorConcatenate, pgsql.ArrayLiteral{ Values: edgeReferences, - CastType: pgsql.Int4Array, + CastType: pgsql.Int8Array, }, ) } diff --git a/packages/go/cypher/models/pgsql/translate/translation.go b/packages/go/cypher/models/pgsql/translate/translation.go index 03eb58cd20..5bc71308a9 100644 --- a/packages/go/cypher/models/pgsql/translate/translation.go +++ b/packages/go/cypher/models/pgsql/translate/translation.go @@ -60,6 +60,76 @@ func (s *Translator) translateRemoveItem(removeItem *cypher.RemoveItem) error { return nil } +func (s *Translator) translatePropertyLookup(lookup *cypher.PropertyLookup) { + if translatedAtom, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + switch typedTranslatedAtom := translatedAtom.(type) { + case pgsql.Identifier: + if fieldIdentifierLiteral, err := pgsql.AsLiteral(lookup.Symbols[0]); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.CompoundIdentifier{typedTranslatedAtom, pgsql.ColumnProperties}) + s.treeTranslator.Push(fieldIdentifierLiteral) + + if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorPropertyLookup); err != nil { + s.SetError(err) + } + } + + case pgsql.FunctionCall: + if fieldIdentifierLiteral, err := pgsql.AsLiteral(lookup.Symbols[0]); err != nil { + s.SetError(err) + } else if componentName, typeOK := fieldIdentifierLiteral.Value.(string); !typeOK { + s.SetErrorf("expected a string component name in translated literal but received type: %T", fieldIdentifierLiteral.Value) + } else { + switch typedTranslatedAtom.Function { + case pgsql.FunctionCurrentDate, pgsql.FunctionLocalTime, pgsql.FunctionCurrentTime, pgsql.FunctionLocalTimestamp, pgsql.FunctionNow: + switch componentName { + case cypher.ITTCEpochSeconds: + s.treeTranslator.Push(pgsql.FunctionCall{ + Function: pgsql.FunctionExtract, + Parameters: []pgsql.Expression{pgsql.ProjectionFrom{ + Projection: []pgsql.SelectItem{ + pgsql.EpochIdentifier, + }, + From: []pgsql.FromClause{{ + Source: translatedAtom, + }}, + }}, + CastType: pgsql.Numeric, + }) + + case cypher.ITTCEpochMilliseconds: + s.treeTranslator.Push(pgsql.NewBinaryExpression( + pgsql.FunctionCall{ + Function: pgsql.FunctionExtract, + Parameters: []pgsql.Expression{pgsql.ProjectionFrom{ + Projection: []pgsql.SelectItem{ + pgsql.EpochIdentifier, + }, + From: []pgsql.FromClause{{ + Source: translatedAtom, + }}, + }}, + CastType: pgsql.Numeric, + }, + pgsql.OperatorMultiply, + pgsql.NewLiteral(1000, pgsql.Int4), + )) + + default: + s.SetErrorf("unsupported date time instant type component %s from function call %s", componentName, typedTranslatedAtom.Function) + } + + default: + s.SetErrorf("unsupported instant type component %s from function call %s", componentName, typedTranslatedAtom.Function) + } + } + } + } +} + func (s *Translator) translateSetItem(setItem *cypher.SetItem) error { if operator, err := translateCypherAssignmentOperator(setItem.Operator); err != nil { return err diff --git a/packages/go/cypher/models/pgsql/translate/translator.go b/packages/go/cypher/models/pgsql/translate/translator.go index 68b057f5f6..02fa92daca 100644 --- a/packages/go/cypher/models/pgsql/translate/translator.go +++ b/packages/go/cypher/models/pgsql/translate/translator.go @@ -128,7 +128,7 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { case *cypher.RegularQuery, *cypher.SingleQuery, *cypher.PatternElement, *cypher.Return, *cypher.Comparison, *cypher.Skip, *cypher.Limit, cypher.Operator, *cypher.ArithmeticExpression, *cypher.NodePattern, *cypher.RelationshipPattern, *cypher.Remove, *cypher.Set, - *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression: + *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression, *cypher.PropertyLookup: // No operation for these syntax nodes case *cypher.Negation: @@ -259,27 +259,6 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) { case *cypher.FunctionInvocation: s.pushState(StateTranslatingNestedExpression) - case *cypher.PropertyLookup: - if variable, isVariable := typedExpression.Atom.(*cypher.Variable); !isVariable { - s.SetErrorf("expected variable for property lookup reference but found type: %T", typedExpression.Atom) - } else if resolved, isResolved := s.query.Scope.LookupString(variable.Symbol); !isResolved { - s.SetErrorf("unable to resolve identifier: %s", variable.Symbol) - } else { - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - // TODO: Cypher does not support nested property references so the Symbols slice should be a string - if fieldIdentifierLiteral, err := pgsql.AsLiteral(typedExpression.Symbols[0]); err != nil { - s.SetError(err) - } else { - s.treeTranslator.Push(pgsql.CompoundIdentifier{resolved.Identifier, pgsql.ColumnProperties}) - s.treeTranslator.Push(fieldIdentifierLiteral) - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } - } - case *cypher.Order: s.pushState(StateTranslatingOrderBy) @@ -610,6 +589,31 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { }) } + case cypher.ListSizeFunction: + if typedExpression.NumArguments() > 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + var functionCall pgsql.FunctionCall + + if _, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + functionCall = pgsql.FunctionCall{ + Function: pgsql.FunctionJSONBArrayLength, + Parameters: []pgsql.Expression{argument}, + CastType: pgsql.Int, + } + } else { + functionCall = pgsql.FunctionCall{ + Function: pgsql.FunctionArrayLength, + Parameters: []pgsql.Expression{argument, pgsql.NewLiteral(1, pgsql.Int)}, + CastType: pgsql.Int, + } + } + + s.treeTranslator.Push(functionCall) + } + case cypher.ToUpperFunction: if typedExpression.NumArguments() > 1 { s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) @@ -628,6 +632,24 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { }) } + case cypher.ToStringFunction: + if typedExpression.NumArguments() > 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Text)) + } + + case cypher.ToIntegerFunction: + if typedExpression.NumArguments() > 1 { + s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name)) + } else if argument, err := s.treeTranslator.Pop(); err != nil { + s.SetError(err) + } else { + s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Int8)) + } + default: s.SetErrorf("unknown cypher function: %s", typedExpression.Name) } @@ -715,24 +737,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } case *cypher.PropertyLookup: - switch currentState := s.currentState(); currentState { - case StateTranslatingNestedExpression: - if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorPropertyLookup); err != nil { - s.SetError(err) - } - - case StateTranslatingProjection: - if nextExpression, err := s.treeTranslator.Pop(); err != nil { - s.SetError(err) - } else if selectItem, isProjection := nextExpression.(pgsql.SelectItem); !isProjection { - s.SetErrorf("invalid type for select item: %T", nextExpression) - } else { - s.projections.CurrentProjection().SelectItem = selectItem - } - - default: - s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression) - } + s.translatePropertyLookup(typedExpression) case *cypher.PartialComparison: switch currentState := s.currentState(); currentState { diff --git a/packages/go/cypher/models/walk/walk_cypher.go b/packages/go/cypher/models/walk/walk_cypher.go index ee7d0a581a..f7f5a35c6d 100644 --- a/packages/go/cypher/models/walk/walk_cypher.go +++ b/packages/go/cypher/models/walk/walk_cypher.go @@ -36,12 +36,18 @@ func cypherSyntaxNodeSliceTypeConvert[F any, FS []F](fs FS) ([]cypher.SyntaxNode func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], error) { switch typedNode := node.(type) { // Types with no AST branches - case *cypher.RangeQuantifier, *cypher.PropertyLookup, cypher.Operator, *cypher.KindMatcher, + case *cypher.RangeQuantifier, cypher.Operator, *cypher.KindMatcher, *cypher.Limit, *cypher.Skip, graph.Kinds, *cypher.Parameter: return &Cursor[cypher.SyntaxNode]{ Node: node, }, nil + case *cypher.PropertyLookup: + return &Cursor[cypher.SyntaxNode]{ + Node: node, + Branches: []cypher.SyntaxNode{typedNode.Atom}, + }, nil + case *cypher.MapItem: return &Cursor[cypher.SyntaxNode]{ Node: node, diff --git a/packages/go/cypher/models/walk/walk_pgsql.go b/packages/go/cypher/models/walk/walk_pgsql.go index 65df9c7cce..29924ae140 100644 --- a/packages/go/cypher/models/walk/walk_pgsql.go +++ b/packages/go/cypher/models/walk/walk_pgsql.go @@ -321,6 +321,16 @@ func newSQLWalkCursor(node pgsql.SyntaxNode) (*Cursor[pgsql.SyntaxNode], error) Branches: []pgsql.SyntaxNode{typedNode.Subquery}, }, nil + case pgsql.ProjectionFrom: + if branches, err := pgsqlSyntaxNodeSliceTypeConvert(typedNode.From); err != nil { + return nil, err + } else { + return &Cursor[pgsql.SyntaxNode]{ + Node: node, + Branches: append([]pgsql.SyntaxNode{typedNode.Projection}, branches...), + }, nil + } + default: return nil, fmt.Errorf("unable to negotiate sql type %T into a translation cursor", node) } diff --git a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql index 3ba6038445..3876967d2c 100644 --- a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql +++ b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql @@ -190,7 +190,6 @@ execute procedure delete_node_edges(); alter table edge alter column properties set storage main; - -- Index on the graph ID of each edge. create index if not exists edge_graph_id_index on edge using btree (graph_id); diff --git a/packages/go/dawgs/drivers/pg/types.go b/packages/go/dawgs/drivers/pg/types.go index a0305b84af..3ce3ae9f01 100644 --- a/packages/go/dawgs/drivers/pg/types.go +++ b/packages/go/dawgs/drivers/pg/types.go @@ -60,11 +60,76 @@ func castMapValueAsSliceOf[T any](compositeMap map[string]any, key string) ([]T, func castAndAssignMapValue[T any](compositeMap map[string]any, key string, dst *T) error { if src, hasKey := compositeMap[key]; !hasKey { return fmt.Errorf("composite map does not contain expected key %s", key) - } else if typed, typeOK := src.(T); !typeOK { - var empty T - return fmt.Errorf("expected type %T but received %T", empty, src) } else { - *dst = typed + switch typedSrc := src.(type) { + case int8: + switch typedDst := any(dst).(type) { + case *int8: + *typedDst = typedSrc + case *int16: + *typedDst = int16(typedSrc) + case *int32: + *typedDst = int32(typedSrc) + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int16: + switch typedDst := any(dst).(type) { + case *int16: + *typedDst = typedSrc + case *int32: + *typedDst = int32(typedSrc) + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int32: + switch typedDst := any(dst).(type) { + case *int32: + *typedDst = typedSrc + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int64: + switch typedDst := any(dst).(type) { + case *int64: + *typedDst = typedSrc + case *int: + *typedDst = int(typedSrc) + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case int: + switch typedDst := any(dst).(type) { + case *int64: + *typedDst = int64(typedSrc) + case *int: + *typedDst = typedSrc + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } + + case T: + *dst = typedSrc + + default: + return fmt.Errorf("unable to cast and assign value type: %T", src) + } } return nil