diff --git a/cmd/api/src/analysis/ad/adcs_integration_test.go b/cmd/api/src/analysis/ad/adcs_integration_test.go
index 39e9bd9213..5a85fe3ed2 100644
--- a/cmd/api/src/analysis/ad/adcs_integration_test.go
+++ b/cmd/api/src/analysis/ad/adcs_integration_test.go
@@ -447,63 +447,75 @@ func TestTrustedForNTAuth(t *testing.T) {
func TestEnrollOnBehalfOf(t *testing.T) {
testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema())
testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
- harness.EnrollOnBehalfOfHarnessOne.Setup(testContext)
+ harness.EnrollOnBehalfOfHarness1.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate)
v1Templates := make([]*graph.Node, 0)
+ v2Templates := make([]*graph.Node, 0)
+
for _, template := range certTemplates {
if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil {
continue
} else if version == 1 {
v1Templates = append(v1Templates, template)
} else if version >= 2 {
- continue
+ v2Templates = append(v2Templates, template)
}
}
require.Nil(t, err)
db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
- results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates)
+ results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1)
require.Nil(t, err)
require.Len(t, results, 3)
require.Contains(t, results, analysis.CreatePostRelationshipJob{
- FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate11.ID,
- ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
+ FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate11.ID,
+ ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})
require.Contains(t, results, analysis.CreatePostRelationshipJob{
- FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate13.ID,
- ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
+ FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate13.ID,
+ ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})
require.Contains(t, results, analysis.CreatePostRelationshipJob{
- FromID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
- ToID: harness.EnrollOnBehalfOfHarnessOne.CertTemplate12.ID,
+ FromID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
+ ToID: harness.EnrollOnBehalfOfHarness1.CertTemplate12.ID,
Kind: ad.EnrollOnBehalfOf,
})
return nil
})
+
+ db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
+ results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness1.Domain1)
+ require.Nil(t, err)
+
+ require.Len(t, results, 0)
+
+ return nil
+ })
})
testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
- harness.EnrollOnBehalfOfHarnessTwo.Setup(testContext)
+ harness.EnrollOnBehalfOfHarness2.Setup(testContext)
return nil
}, func(harness integration.HarnessDetails, db graph.Database) {
certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate)
+ v1Templates := make([]*graph.Node, 0)
v2Templates := make([]*graph.Node, 0)
for _, template := range certTemplates {
if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil {
continue
} else if version == 1 {
- continue
+ v1Templates = append(v1Templates, template)
} else if version >= 2 {
v2Templates = append(v2Templates, template)
}
@@ -512,15 +524,60 @@ func TestEnrollOnBehalfOf(t *testing.T) {
require.Nil(t, err)
db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
- results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates)
+ results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2)
+ require.Nil(t, err)
+
+ require.Len(t, results, 0)
+ return nil
+ })
+
+ db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
+ results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates, harness.EnrollOnBehalfOfHarness2.Domain2)
require.Nil(t, err)
require.Len(t, results, 1)
require.Contains(t, results, analysis.CreatePostRelationshipJob{
- FromID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate21.ID,
- ToID: harness.EnrollOnBehalfOfHarnessTwo.CertTemplate23.ID,
+ FromID: harness.EnrollOnBehalfOfHarness2.CertTemplate21.ID,
+ ToID: harness.EnrollOnBehalfOfHarness2.CertTemplate23.ID,
Kind: ad.EnrollOnBehalfOf,
})
+ return nil
+ })
+ })
+
+ testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error {
+ harness.EnrollOnBehalfOfHarness3.Setup(testContext)
+ return nil
+ }, func(harness integration.HarnessDetails, db graph.Database) {
+ operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - EnrollOnBehalfOf 3")
+
+ _, enterpriseCertAuthorities, certTemplates, domains, cache, err := FetchADCSPrereqs(db)
+ require.Nil(t, err)
+
+ if err := ad2.PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates, cache, operation); err != nil {
+ t.Logf("failed post processing for %s: %v", ad.EnrollOnBehalfOf.String(), err)
+ }
+ err = operation.Done()
+ require.Nil(t, err)
+
+ db.ReadTransaction(context.Background(), func(tx graph.Transaction) error {
+ if startNodes, err := ops.FetchStartNodes(tx.Relationships().Filterf(func() graph.Criteria {
+ return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf)
+ })); err != nil {
+ t.Fatalf("error fetching EnrollOnBehalfOf edges in integration test; %v", err)
+ } else if endNodes, err := ops.FetchEndNodes(tx.Relationships().Filterf(func() graph.Criteria {
+ return query.Kind(query.Relationship(), ad.EnrollOnBehalfOf)
+ })); err != nil {
+ t.Fatalf("error fetching end nodes of EnrollOnBehalfOf edges in integration test; %v", err)
+ } else {
+ require.Len(t, startNodes, 2)
+ require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate11))
+ require.True(t, startNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
+
+ require.Len(t, endNodes, 2)
+ require.True(t, endNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
+ require.True(t, endNodes.Contains(harness.EnrollOnBehalfOfHarness3.CertTemplate12))
+ }
return nil
})
diff --git a/cmd/api/src/api/v2/datapipe_integration_test.go b/cmd/api/src/api/v2/datapipe_integration_test.go
index e992a312ae..8caf3500a4 100644
--- a/cmd/api/src/api/v2/datapipe_integration_test.go
+++ b/cmd/api/src/api/v2/datapipe_integration_test.go
@@ -4,7 +4,7 @@
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
-// http://www.apache.org/licenses/LICENSE2.0
+// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//
-// SPDXLicenseIdentifier: Apache2.0
+// SPDX-License-Identifier: Apache-2.0
//go:build serial_integration
// +build serial_integration
diff --git a/cmd/api/src/database/migration/migrations/v6.3.0.sql b/cmd/api/src/database/migration/migrations/v6.3.0.sql
index b21d0632cc..c44a919ced 100644
--- a/cmd/api/src/database/migration/migrations/v6.3.0.sql
+++ b/cmd/api/src/database/migration/migrations/v6.3.0.sql
@@ -33,3 +33,6 @@ UPDATE feature_flags SET enabled = true WHERE key = 'updated_posture_page';
-- Fix users in bad state due to sso bug
DELETE FROM auth_secrets WHERE id IN (SELECT auth_secrets.id FROM auth_secrets JOIN users ON users.id = auth_secrets.user_id WHERE users.sso_provider_id IS NOT NULL);
+
+-- Set the `oidc_support` feature flag to true
+UPDATE feature_flags SET enabled = true WHERE key = 'oidc_support';
\ No newline at end of file
diff --git a/cmd/api/src/test/integration/harnesses.go b/cmd/api/src/test/integration/harnesses.go
index 637aebc7c0..dede7f384a 100644
--- a/cmd/api/src/test/integration/harnesses.go
+++ b/cmd/api/src/test/integration/harnesses.go
@@ -1595,7 +1595,7 @@ func (s *ADCSESC1HarnessAuthUsers) Setup(graphTestContext *GraphTestContext) {
graphTestContext.UpdateNode(s.AuthUsers)
}
-type EnrollOnBehalfOfHarnessTwo struct {
+type EnrollOnBehalfOfHarness2 struct {
Domain2 *graph.Node
AuthStore2 *graph.Node
RootCA2 *graph.Node
@@ -1604,10 +1604,9 @@ type EnrollOnBehalfOfHarnessTwo struct {
CertTemplate22 *graph.Node
CertTemplate23 *graph.Node
CertTemplate24 *graph.Node
- CertTemplate25 *graph.Node
}
-func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) {
+func (s *EnrollOnBehalfOfHarness2) Setup(gt *GraphTestContext) {
certRequestAgentEKU := make([]string, 0)
certRequestAgentEKU = append(certRequestAgentEKU, adAnalysis.EkuCertRequestAgent)
emptyAppPolicies := make([]string, 0)
@@ -1623,7 +1622,7 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) {
SubjectAltRequireUPN: false,
SubjectAltRequireSPN: false,
NoSecurityExtension: false,
- SchemaVersion: 1,
+ SchemaVersion: 2,
AuthorizedSignatures: 0,
EffectiveEKUs: certRequestAgentEKU,
ApplicationPolicies: emptyAppPolicies,
@@ -1635,7 +1634,7 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) {
SubjectAltRequireUPN: false,
SubjectAltRequireSPN: false,
NoSecurityExtension: false,
- SchemaVersion: 1,
+ SchemaVersion: 2,
AuthorizedSignatures: 0,
EffectiveEKUs: []string{adAnalysis.EkuCertRequestAgent, adAnalysis.EkuAnyPurpose},
ApplicationPolicies: emptyAppPolicies,
@@ -1664,18 +1663,6 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) {
EffectiveEKUs: emptyAppPolicies,
ApplicationPolicies: emptyAppPolicies,
})
- s.CertTemplate25 = gt.NewActiveDirectoryCertTemplate("certtemplate2-5", sid, CertTemplateData{
- RequiresManagerApproval: false,
- AuthenticationEnabled: false,
- EnrolleeSuppliesSubject: false,
- SubjectAltRequireUPN: false,
- SubjectAltRequireSPN: false,
- NoSecurityExtension: false,
- SchemaVersion: 1,
- AuthorizedSignatures: 1,
- EffectiveEKUs: emptyAppPolicies,
- ApplicationPolicies: emptyAppPolicies,
- })
gt.NewRelationship(s.AuthStore2, s.Domain2, ad.NTAuthStoreFor)
gt.NewRelationship(s.RootCA2, s.Domain2, ad.RootCAFor)
@@ -1685,10 +1672,9 @@ func (s *EnrollOnBehalfOfHarnessTwo) Setup(gt *GraphTestContext) {
gt.NewRelationship(s.CertTemplate22, s.EnterpriseCA2, ad.PublishedTo)
gt.NewRelationship(s.CertTemplate23, s.EnterpriseCA2, ad.PublishedTo)
gt.NewRelationship(s.CertTemplate24, s.EnterpriseCA2, ad.PublishedTo)
- gt.NewRelationship(s.CertTemplate25, s.EnterpriseCA2, ad.PublishedTo)
}
-type EnrollOnBehalfOfHarnessOne struct {
+type EnrollOnBehalfOfHarness1 struct {
Domain1 *graph.Node
AuthStore1 *graph.Node
RootCA1 *graph.Node
@@ -1698,7 +1684,7 @@ type EnrollOnBehalfOfHarnessOne struct {
CertTemplate13 *graph.Node
}
-func (s *EnrollOnBehalfOfHarnessOne) Setup(gt *GraphTestContext) {
+func (s *EnrollOnBehalfOfHarness1) Setup(gt *GraphTestContext) {
sid := RandomDomainSID()
anyPurposeEkus := make([]string, 0)
anyPurposeEkus = append(anyPurposeEkus, adAnalysis.EkuAnyPurpose)
@@ -1753,6 +1739,74 @@ func (s *EnrollOnBehalfOfHarnessOne) Setup(gt *GraphTestContext) {
gt.NewRelationship(s.CertTemplate13, s.EnterpriseCA1, ad.PublishedTo)
}
+type EnrollOnBehalfOfHarness3 struct {
+ Domain1 *graph.Node
+ AuthStore1 *graph.Node
+ RootCA1 *graph.Node
+ EnterpriseCA1 *graph.Node
+ EnterpriseCA2 *graph.Node
+ CertTemplate11 *graph.Node
+ CertTemplate12 *graph.Node
+ CertTemplate13 *graph.Node
+}
+
+func (s *EnrollOnBehalfOfHarness3) Setup(gt *GraphTestContext) {
+ sid := RandomDomainSID()
+ anyPurposeEkus := make([]string, 0)
+ anyPurposeEkus = append(anyPurposeEkus, adAnalysis.EkuAnyPurpose)
+ emptyAppPolicies := make([]string, 0)
+ s.Domain1 = gt.NewActiveDirectoryDomain("domain1", sid, false, true)
+ s.AuthStore1 = gt.NewActiveDirectoryNTAuthStore("authstore1", sid)
+ s.RootCA1 = gt.NewActiveDirectoryRootCA("rca1", sid)
+ s.EnterpriseCA1 = gt.NewActiveDirectoryEnterpriseCA("eca1", sid)
+ s.EnterpriseCA2 = gt.NewActiveDirectoryEnterpriseCA("eca2", sid)
+ s.CertTemplate11 = gt.NewActiveDirectoryCertTemplate("certtemplate1-1", sid, CertTemplateData{
+ RequiresManagerApproval: false,
+ AuthenticationEnabled: false,
+ EnrolleeSuppliesSubject: false,
+ SubjectAltRequireUPN: false,
+ SubjectAltRequireSPN: false,
+ NoSecurityExtension: false,
+ SchemaVersion: 2,
+ AuthorizedSignatures: 0,
+ EffectiveEKUs: anyPurposeEkus,
+ ApplicationPolicies: emptyAppPolicies,
+ })
+ s.CertTemplate12 = gt.NewActiveDirectoryCertTemplate("certtemplate1-2", sid, CertTemplateData{
+ RequiresManagerApproval: false,
+ AuthenticationEnabled: false,
+ EnrolleeSuppliesSubject: false,
+ SubjectAltRequireUPN: false,
+ SubjectAltRequireSPN: false,
+ NoSecurityExtension: false,
+ SchemaVersion: 1,
+ AuthorizedSignatures: 0,
+ EffectiveEKUs: anyPurposeEkus,
+ ApplicationPolicies: emptyAppPolicies,
+ })
+ s.CertTemplate13 = gt.NewActiveDirectoryCertTemplate("certtemplate1-3", sid, CertTemplateData{
+ RequiresManagerApproval: false,
+ AuthenticationEnabled: false,
+ EnrolleeSuppliesSubject: false,
+ SubjectAltRequireUPN: false,
+ SubjectAltRequireSPN: false,
+ NoSecurityExtension: false,
+ SchemaVersion: 2,
+ AuthorizedSignatures: 0,
+ EffectiveEKUs: anyPurposeEkus,
+ ApplicationPolicies: emptyAppPolicies,
+ })
+
+ gt.NewRelationship(s.AuthStore1, s.Domain1, ad.NTAuthStoreFor)
+ gt.NewRelationship(s.RootCA1, s.Domain1, ad.RootCAFor)
+ gt.NewRelationship(s.EnterpriseCA1, s.AuthStore1, ad.TrustedForNTAuth)
+ gt.NewRelationship(s.EnterpriseCA1, s.RootCA1, ad.EnterpriseCAFor)
+ gt.NewRelationship(s.EnterpriseCA2, s.RootCA1, ad.EnterpriseCAFor)
+ gt.NewRelationship(s.CertTemplate11, s.EnterpriseCA1, ad.PublishedTo)
+ gt.NewRelationship(s.CertTemplate12, s.EnterpriseCA1, ad.PublishedTo)
+ gt.NewRelationship(s.CertTemplate13, s.EnterpriseCA2, ad.PublishedTo)
+}
+
type ADCSGoldenCertHarness struct {
NTAuthStore1 *graph.Node
RootCA1 *graph.Node
@@ -8437,8 +8491,9 @@ type HarnessDetails struct {
ShortcutHarnessEveryone2 ShortcutHarnessEveryone2
ADCSESC1Harness ADCSESC1Harness
ADCSESC1HarnessAuthUsers ADCSESC1HarnessAuthUsers
- EnrollOnBehalfOfHarnessOne EnrollOnBehalfOfHarnessOne
- EnrollOnBehalfOfHarnessTwo EnrollOnBehalfOfHarnessTwo
+ EnrollOnBehalfOfHarness1 EnrollOnBehalfOfHarness1
+ EnrollOnBehalfOfHarness2 EnrollOnBehalfOfHarness2
+ EnrollOnBehalfOfHarness3 EnrollOnBehalfOfHarness3
ADCSGoldenCertHarness ADCSGoldenCertHarness
IssuedSignedByHarness IssuedSignedByHarness
EnterpriseCAForHarness EnterpriseCAForHarness
diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json
index cc51045b70..9232e6f121 100644
--- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json
+++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.json
@@ -54,10 +54,10 @@
},
"nodes": [
{
- "id": "n1",
+ "id": "n0",
"position": {
- "x": 729.9551990267428,
- "y": -4
+ "x": 675.9551990267428,
+ "y": 50
},
"caption": "Domain1",
"labels": [],
@@ -67,10 +67,10 @@
}
},
{
- "id": "n2",
+ "id": "n1",
"position": {
- "x": 129,
- "y": 273.97628342478527
+ "x": 75,
+ "y": 327.97628342478527
},
"caption": "CertTemplate1-1",
"labels": [],
@@ -83,10 +83,10 @@
}
},
{
- "id": "n3",
+ "id": "n2",
"position": {
- "x": 487.6313891898351,
- "y": -4
+ "x": 433.6313891898351,
+ "y": 50
},
"caption": "NTAuthStore1",
"labels": [],
@@ -96,10 +96,10 @@
}
},
{
- "id": "n4",
+ "id": "n3",
"position": {
- "x": 487.6313891898351,
- "y": 273.97628342478527
+ "x": 433.6313891898351,
+ "y": 327.97628342478527
},
"caption": "EnterpriseCA1",
"labels": [],
@@ -109,10 +109,10 @@
}
},
{
- "id": "n5",
+ "id": "n4",
"position": {
- "x": 230.03558347937087,
- "y": 551.9525668495705
+ "x": 176.03558347937087,
+ "y": 605.9525668495705
},
"caption": "CertTemplate1-2",
"labels": [],
@@ -125,10 +125,10 @@
}
},
{
- "id": "n6",
+ "id": "n5",
"position": {
- "x": 508.01086036954564,
- "y": 551.2045130499298
+ "x": 454.01086036954564,
+ "y": 605.2045130499298
},
"caption": "CertTemplate1-3",
"labels": [],
@@ -141,10 +141,10 @@
}
},
{
- "id": "n7",
+ "id": "n6",
"position": {
- "x": 729.9551990267428,
- "y": 273.97628342478527
+ "x": 675.9551990267428,
+ "y": 327.97628342478527
},
"caption": "RootCA1",
"labels": [],
@@ -157,64 +157,64 @@
"relationships": [
{
"id": "n0",
- "fromId": "n3",
- "toId": "n1",
+ "fromId": "n2",
+ "toId": "n0",
"type": "NTAuthStoreFor",
"properties": {},
"style": {}
},
{
"id": "n1",
- "fromId": "n4",
- "toId": "n3",
+ "fromId": "n3",
+ "toId": "n2",
"type": "TrustedForNTAuth",
"properties": {},
"style": {}
},
{
"id": "n2",
- "fromId": "n2",
- "toId": "n4",
+ "fromId": "n1",
+ "toId": "n3",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
"id": "n3",
- "fromId": "n5",
- "toId": "n4",
+ "fromId": "n4",
+ "toId": "n3",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
"id": "n4",
- "fromId": "n6",
- "toId": "n4",
+ "fromId": "n5",
+ "toId": "n3",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
"id": "n5",
- "fromId": "n7",
- "toId": "n1",
+ "fromId": "n6",
+ "toId": "n0",
"type": "RootCAFor",
"properties": {},
"style": {}
},
{
"id": "n6",
- "fromId": "n4",
- "toId": "n7",
+ "fromId": "n3",
+ "toId": "n6",
"type": "EnterpriseCAFor",
"properties": {},
"style": {}
},
{
"id": "n7",
- "fromId": "n2",
- "toId": "n5",
+ "fromId": "n1",
+ "toId": "n4",
"type": "EnrollOnBehalfOf",
"properties": {},
"style": {
@@ -223,13 +223,23 @@
},
{
"id": "n8",
- "fromId": "n6",
- "toId": "n5",
+ "fromId": "n5",
+ "toId": "n4",
"type": "EnrollOnBehalfOf",
"properties": {},
"style": {
"arrow-color": "#68ccca"
}
+ },
+ {
+ "id": "n9",
+ "type": "EnrollOnBehalfOf",
+ "style": {
+ "arrow-color": "#68ccca"
+ },
+ "properties": {},
+ "fromId": "n4",
+ "toId": "n4"
}
]
}
\ No newline at end of file
diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg
index a86d80d2f0..6fcfa6838d 100644
--- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg
+++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-1.svg
@@ -1,18 +1 @@
-
-
+
\ No newline at end of file
diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json
index ee3e2a4b43..607f78cbef 100644
--- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json
+++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.json
@@ -1,10 +1,63 @@
{
+ "style": {
+ "font-family": "sans-serif",
+ "background-color": "#ffffff",
+ "background-image": "",
+ "background-size": "100%",
+ "node-color": "#ffffff",
+ "border-width": 4,
+ "border-color": "#000000",
+ "radius": 50,
+ "node-padding": 5,
+ "node-margin": 2,
+ "outside-position": "auto",
+ "node-icon-image": "",
+ "node-background-image": "",
+ "icon-position": "inside",
+ "icon-size": 64,
+ "caption-position": "inside",
+ "caption-max-width": 200,
+ "caption-color": "#000000",
+ "caption-font-size": 50,
+ "caption-font-weight": "normal",
+ "label-position": "inside",
+ "label-display": "pill",
+ "label-color": "#000000",
+ "label-background-color": "#ffffff",
+ "label-border-color": "#000000",
+ "label-border-width": 4,
+ "label-font-size": 40,
+ "label-padding": 5,
+ "label-margin": 4,
+ "directionality": "directed",
+ "detail-position": "inline",
+ "detail-orientation": "parallel",
+ "arrow-width": 5,
+ "arrow-color": "#000000",
+ "margin-start": 5,
+ "margin-end": 5,
+ "margin-peer": 20,
+ "attachment-start": "normal",
+ "attachment-end": "normal",
+ "relationship-icon-image": "",
+ "type-color": "#000000",
+ "type-background-color": "#ffffff",
+ "type-border-color": "#000000",
+ "type-border-width": 0,
+ "type-font-size": 16,
+ "type-padding": 5,
+ "property-position": "outside",
+ "property-alignment": "colon",
+ "property-color": "#000000",
+ "property-font-size": 16,
+ "property-font-weight": "normal"
+ },
"nodes": [
{
"id": "n0",
"position": {
- "x": -569.1685177598522,
- "y": -1021.0927494329366
+ "x": 657.3454879680903,
+ "y": 50
},
"caption": "Domain2",
"labels": [],
@@ -14,10 +67,10 @@
}
},
{
- "id": "n2",
+ "id": "n1",
"position": {
- "x": -811.4923275967599,
- "y": -1021.0927494329366
+ "x": 415.0216781311826,
+ "y": 50
},
"caption": "NTAuthStore2",
"labels": [],
@@ -27,10 +80,10 @@
}
},
{
- "id": "n5",
+ "id": "n2",
"position": {
- "x": -811.4923275967599,
- "y": -743.1164660081513
+ "x": 415.0216781311826,
+ "y": 327.97628342478527
},
"caption": "EnterpriseCA2",
"labels": [],
@@ -40,10 +93,10 @@
}
},
{
- "id": "n13",
+ "id": "n3",
"position": {
- "x": -569.1685177598522,
- "y": -743.1164660081513
+ "x": 657.3454879680903,
+ "y": 327.97628342478527
},
"caption": "RootCA2",
"labels": [],
@@ -53,40 +106,42 @@
}
},
{
- "id": "n14",
+ "id": "n4",
"position": {
- "x": -1151.5140057279425,
- "y": -743.1164660081513
+ "x": 75,
+ "y": 327.97628342478527
},
"caption": "CertTemplate2-1",
"labels": [],
"properties": {
- "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\"]"
+ "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\"]",
+ "schemaversion": "2"
},
"style": {
"node-color": "#fda1ff"
}
},
{
- "id": "n15",
+ "id": "n5",
"position": {
- "x": -1151.5140057279425,
- "y": -546.8048586088046
+ "x": 75,
+ "y": 524.287890824132
},
"caption": "CertTemplate2-2",
"labels": [],
"properties": {
- "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\", \"2.5.29.37.0\"]"
+ "effectiveekus": "[\"1.3.6.1.4.1.311.20.2.1\", \"2.5.29.37.0\"]",
+ "schemaversion": "2"
},
"style": {
"node-color": "#fda1ff"
}
},
{
- "id": "n16",
+ "id": "n6",
"position": {
- "x": -981.5031666623508,
- "y": -448.6490549091318
+ "x": 245.01083906559165,
+ "y": 622.4436945238048
},
"caption": "CertTemplate2-3",
"labels": [],
@@ -101,10 +156,10 @@
}
},
{
- "id": "n17",
+ "id": "n7",
"position": {
- "x": -695.4923275967599,
- "y": -448.6490549091318
+ "x": 531.0216781311826,
+ "y": 622.4436945238048
},
"caption": "CertTemplate2-4",
"labels": [],
@@ -117,29 +172,12 @@
"style": {
"node-color": "#fda1ff"
}
- },
- {
- "id": "n18",
- "position": {
- "x": -517.02491649774,
- "y": -448.6490549091315
- },
- "caption": "CertTemplate2-5",
- "labels": [],
- "properties": {
- "effectiveekus": "[]",
- "schemaversion": "1",
- "subjectaltrequiresupn": "true"
- },
- "style": {
- "node-color": "#fda1ff"
- }
}
],
"relationships": [
{
"id": "n0",
- "fromId": "n2",
+ "fromId": "n1",
"toId": "n0",
"type": "NTAuthStoreFor",
"properties": {},
@@ -147,141 +185,69 @@
},
{
"id": "n1",
- "fromId": "n5",
- "toId": "n2",
+ "fromId": "n2",
+ "toId": "n1",
"type": "TrustedForNTAuth",
"properties": {},
"style": {}
},
{
- "id": "n9",
- "fromId": "n13",
+ "id": "n2",
+ "fromId": "n3",
"toId": "n0",
"type": "RootCAFor",
"properties": {},
"style": {}
},
{
- "id": "n10",
- "fromId": "n5",
- "toId": "n13",
+ "id": "n3",
+ "fromId": "n2",
+ "toId": "n3",
"type": "EnterpriseCAFor",
"properties": {},
"style": {}
},
{
- "id": "n11",
- "fromId": "n14",
- "toId": "n5",
- "type": "PublishedTo",
- "properties": {},
- "style": {}
- },
- {
- "id": "n12",
- "fromId": "n15",
- "toId": "n5",
+ "id": "n4",
+ "fromId": "n4",
+ "toId": "n2",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
- "id": "n13",
- "fromId": "n16",
- "toId": "n5",
+ "id": "n5",
+ "fromId": "n5",
+ "toId": "n2",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
- "id": "n14",
- "fromId": "n17",
- "toId": "n5",
+ "id": "n6",
+ "fromId": "n6",
+ "toId": "n2",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
- "id": "n15",
- "type": "EnrollOnBehalfOf",
- "style": {
- "arrow-color": "#a4dd00"
- },
- "properties": {},
- "toId": "n16",
- "fromId": "n14"
- },
- {
- "id": "n16",
- "fromId": "n18",
- "toId": "n5",
+ "id": "n7",
+ "fromId": "n7",
+ "toId": "n2",
"type": "PublishedTo",
"properties": {},
"style": {}
},
{
- "id": "n17",
- "fromId": "n18",
- "toId": "n18",
+ "id": "n8",
+ "fromId": "n4",
+ "toId": "n6",
"type": "EnrollOnBehalfOf",
"properties": {},
"style": {
- "type-color": "#4d4d4d",
"arrow-color": "#a4dd00"
}
}
- ],
- "style": {
- "font-family": "sans-serif",
- "background-color": "#ffffff",
- "background-image": "",
- "background-size": "100%",
- "node-color": "#ffffff",
- "border-width": 4,
- "border-color": "#000000",
- "radius": 50,
- "node-padding": 5,
- "node-margin": 2,
- "outside-position": "auto",
- "node-icon-image": "",
- "node-background-image": "",
- "icon-position": "inside",
- "icon-size": 64,
- "caption-position": "inside",
- "caption-max-width": 200,
- "caption-color": "#000000",
- "caption-font-size": 50,
- "caption-font-weight": "normal",
- "label-position": "inside",
- "label-display": "pill",
- "label-color": "#000000",
- "label-background-color": "#ffffff",
- "label-border-color": "#000000",
- "label-border-width": 4,
- "label-font-size": 40,
- "label-padding": 5,
- "label-margin": 4,
- "directionality": "directed",
- "detail-position": "inline",
- "detail-orientation": "parallel",
- "arrow-width": 5,
- "arrow-color": "#000000",
- "margin-start": 5,
- "margin-end": 5,
- "margin-peer": 20,
- "attachment-start": "normal",
- "attachment-end": "normal",
- "relationship-icon-image": "",
- "type-color": "#000000",
- "type-background-color": "#ffffff",
- "type-border-color": "#000000",
- "type-border-width": 0,
- "type-font-size": 16,
- "type-padding": 5,
- "property-position": "outside",
- "property-alignment": "colon",
- "property-color": "#000000",
- "property-font-size": 16,
- "property-font-weight": "normal"
- }
+ ]
}
\ No newline at end of file
diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg
index dd203c5b22..d8d8ec0fcb 100644
--- a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg
+++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-2.svg
@@ -1,18 +1 @@
-
-
+
\ No newline at end of file
diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.json b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.json
new file mode 100644
index 0000000000..e0537e9ec6
--- /dev/null
+++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.json
@@ -0,0 +1,256 @@
+{
+ "style": {
+ "font-family": "sans-serif",
+ "background-color": "#ffffff",
+ "background-image": "",
+ "background-size": "100%",
+ "node-color": "#ffffff",
+ "border-width": 4,
+ "border-color": "#000000",
+ "radius": 50,
+ "node-padding": 5,
+ "node-margin": 2,
+ "outside-position": "auto",
+ "node-icon-image": "",
+ "node-background-image": "",
+ "icon-position": "inside",
+ "icon-size": 64,
+ "caption-position": "inside",
+ "caption-max-width": 200,
+ "caption-color": "#000000",
+ "caption-font-size": 50,
+ "caption-font-weight": "normal",
+ "label-position": "inside",
+ "label-display": "pill",
+ "label-color": "#000000",
+ "label-background-color": "#ffffff",
+ "label-border-color": "#000000",
+ "label-border-width": 4,
+ "label-font-size": 40,
+ "label-padding": 5,
+ "label-margin": 4,
+ "directionality": "directed",
+ "detail-position": "inline",
+ "detail-orientation": "parallel",
+ "arrow-width": 5,
+ "arrow-color": "#000000",
+ "margin-start": 5,
+ "margin-end": 5,
+ "margin-peer": 20,
+ "attachment-start": "normal",
+ "attachment-end": "normal",
+ "relationship-icon-image": "",
+ "type-color": "#000000",
+ "type-background-color": "#ffffff",
+ "type-border-color": "#000000",
+ "type-border-width": 0,
+ "type-font-size": 16,
+ "type-padding": 5,
+ "property-position": "outside",
+ "property-alignment": "colon",
+ "property-color": "#000000",
+ "property-font-size": 16,
+ "property-font-weight": "normal"
+ },
+ "nodes": [
+ {
+ "id": "n0",
+ "position": {
+ "x": 675.9551990267428,
+ "y": 50
+ },
+ "caption": "Domain1",
+ "labels": [],
+ "properties": {},
+ "style": {
+ "node-color": "#68ccca"
+ }
+ },
+ {
+ "id": "n1",
+ "position": {
+ "x": -44.95912785436991,
+ "y": 327.97628342478527
+ },
+ "caption": "CertTemplate1-1",
+ "labels": [],
+ "properties": {
+ "schemaversion": "2",
+ "effectiveekus": "[\"2.5.29.37.0\"]"
+ },
+ "style": {
+ "node-color": "#fda1ff"
+ }
+ },
+ {
+ "id": "n2",
+ "position": {
+ "x": 433.6313891898351,
+ "y": 50
+ },
+ "caption": "NTAuthStore1",
+ "labels": [],
+ "properties": {},
+ "style": {
+ "node-color": "#7b64ff"
+ }
+ },
+ {
+ "id": "n3",
+ "position": {
+ "x": 433.6313891898351,
+ "y": 327.97628342478527
+ },
+ "caption": "EnterpriseCA1",
+ "labels": [],
+ "properties": {},
+ "style": {
+ "node-color": "#b0bc00"
+ }
+ },
+ {
+ "id": "n4",
+ "position": {
+ "x": 194.3361306677326,
+ "y": 557.3609078601627
+ },
+ "caption": "CertTemplate1-2",
+ "labels": [],
+ "properties": {
+ "schemaversion": "1",
+ "effectiveekus": "[\"2.5.29.37.0\"]"
+ },
+ "style": {
+ "node-color": "#fda1ff"
+ }
+ },
+ {
+ "id": "n5",
+ "position": {
+ "x": 433.6313891898351,
+ "y": 796.6561663822652
+ },
+ "caption": "CertTemplate1-3",
+ "labels": [],
+ "properties": {
+ "schemaversion": "2",
+ "effectiveekus": "[\"2.5.29.37.0\"]"
+ },
+ "style": {
+ "node-color": "#fda1ff"
+ }
+ },
+ {
+ "id": "n6",
+ "position": {
+ "x": 675.9551990267428,
+ "y": 327.97628342478527
+ },
+ "caption": "RootCA1",
+ "labels": [],
+ "properties": {},
+ "style": {
+ "node-color": "#e27300"
+ }
+ },
+ {
+ "id": "n7",
+ "position": {
+ "x": 433.6313891898351,
+ "y": 500.5114025285298
+ },
+ "caption": "EnterpriseCA2",
+ "style": {
+ "node-color": "#b0bc00"
+ },
+ "labels": [],
+ "properties": {}
+ }
+ ],
+ "relationships": [
+ {
+ "id": "n0",
+ "fromId": "n2",
+ "toId": "n0",
+ "type": "NTAuthStoreFor",
+ "properties": {},
+ "style": {}
+ },
+ {
+ "id": "n1",
+ "fromId": "n3",
+ "toId": "n2",
+ "type": "TrustedForNTAuth",
+ "properties": {},
+ "style": {}
+ },
+ {
+ "id": "n2",
+ "fromId": "n1",
+ "toId": "n3",
+ "type": "PublishedTo",
+ "properties": {},
+ "style": {}
+ },
+ {
+ "id": "n3",
+ "fromId": "n4",
+ "toId": "n3",
+ "type": "PublishedTo",
+ "properties": {},
+ "style": {}
+ },
+ {
+ "id": "n5",
+ "fromId": "n6",
+ "toId": "n0",
+ "type": "RootCAFor",
+ "properties": {},
+ "style": {}
+ },
+ {
+ "id": "n6",
+ "fromId": "n3",
+ "toId": "n6",
+ "type": "EnterpriseCAFor",
+ "properties": {},
+ "style": {}
+ },
+ {
+ "id": "n7",
+ "fromId": "n1",
+ "toId": "n4",
+ "type": "EnrollOnBehalfOf",
+ "properties": {},
+ "style": {
+ "arrow-color": "#68ccca"
+ }
+ },
+ {
+ "id": "n12",
+ "type": "PublishedTo",
+ "fromId": "n5",
+ "toId": "n7",
+ "style": {},
+ "properties": {}
+ },
+ {
+ "id": "n13",
+ "type": "EnterpriseCAFor",
+ "fromId": "n7",
+ "toId": "n6",
+ "style": {},
+ "properties": {}
+ },
+ {
+ "id": "n14",
+ "type": "EnrollOnBehalfOf",
+ "style": {
+ "arrow-color": "#68ccca"
+ },
+ "properties": {},
+ "fromId": "n4",
+ "toId": "n4"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg
new file mode 100644
index 0000000000..d7192c6074
--- /dev/null
+++ b/cmd/api/src/test/integration/harnesses/enrollonbehalfof-3.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/packages/go/analysis/ad/adcs.go b/packages/go/analysis/ad/adcs.go
index 537b0b0770..ac1f9a76cb 100644
--- a/packages/go/analysis/ad/adcs.go
+++ b/packages/go/analysis/ad/adcs.go
@@ -47,30 +47,33 @@ func PostADCS(ctx context.Context, db graph.Database, groupExpansions impact.Pat
return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed fetching domain nodes: %w", err)
} else if step1Stats, err := postADCSPreProcessStep1(ctx, db, enterpriseCertAuthorities, rootCertAuthorities, aiaCertAuthorities, certTemplates); err != nil {
return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed adcs pre-processing step 1: %w", err)
- } else if step2Stats, err := postADCSPreProcessStep2(ctx, db, certTemplates); err != nil {
- return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed adcs pre-processing step 2: %w", err)
} else {
- operation := analysis.NewPostRelationshipOperation(ctx, db, "ADCS Post Processing")
-
- operation.Stats.Merge(step1Stats)
- operation.Stats.Merge(step2Stats)
-
var cache = NewADCSCache()
cache.BuildCache(ctx, db, enterpriseCertAuthorities, certTemplates, domains)
- for _, domain := range domains {
- innerDomain := domain
+ if step2Stats, err := postADCSPreProcessStep2(ctx, db, domains, enterpriseCertAuthorities, certTemplates, cache); err != nil {
+ return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed adcs pre-processing step 2: %w", err)
+ } else {
+
+ operation := analysis.NewPostRelationshipOperation(ctx, db, "ADCS Post Processing")
+
+ operation.Stats.Merge(step1Stats)
+ operation.Stats.Merge(step2Stats)
- for _, enterpriseCA := range enterpriseCertAuthorities {
- innerEnterpriseCA := enterpriseCA
+ for _, domain := range domains {
+ innerDomain := domain
- if cache.DoesCAChainProperlyToDomain(innerEnterpriseCA, innerDomain) {
- processEnterpriseCAWithValidCertChainToDomain(innerEnterpriseCA, innerDomain, groupExpansions, cache, operation)
+ for _, enterpriseCA := range enterpriseCertAuthorities {
+ innerEnterpriseCA := enterpriseCA
+
+ if cache.DoesCAChainProperlyToDomain(innerEnterpriseCA, innerDomain) {
+ processEnterpriseCAWithValidCertChainToDomain(innerEnterpriseCA, innerDomain, groupExpansions, cache, operation)
+ }
}
}
- }
+ return &operation.Stats, operation.Done()
- return &operation.Stats, operation.Done()
+ }
}
}
@@ -97,10 +100,10 @@ func postADCSPreProcessStep1(ctx context.Context, db graph.Database, enterpriseC
}
// postADCSPreProcessStep2 Processes the edges that are dependent on those processed in postADCSPreProcessStep1
-func postADCSPreProcessStep2(ctx context.Context, db graph.Database, certTemplates []*graph.Node) (*analysis.AtomicPostProcessingStats, error) {
+func postADCSPreProcessStep2(ctx context.Context, db graph.Database, domains, enterpriseCertAuthorities, certTemplates []*graph.Node, cache ADCSCache) (*analysis.AtomicPostProcessingStats, error) {
operation := analysis.NewPostRelationshipOperation(ctx, db, "ADCS Post Processing Step 2")
- if err := PostEnrollOnBehalfOf(certTemplates, operation); err != nil {
+ if err := PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates, cache, operation); err != nil {
operation.Done()
return &analysis.AtomicPostProcessingStats{}, fmt.Errorf("failed post processing for %s: %w", ad.EnrollOnBehalfOf.String(), err)
} else {
diff --git a/packages/go/analysis/ad/esc3.go b/packages/go/analysis/ad/esc3.go
index 2406e0081b..28abe1fcec 100644
--- a/packages/go/analysis/ad/esc3.go
+++ b/packages/go/analysis/ad/esc3.go
@@ -142,60 +142,71 @@ func PostADCSESC3(ctx context.Context, tx graph.Transaction, outC chan<- analysi
return nil
}
-func PostEnrollOnBehalfOf(certTemplates []*graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) error {
+func PostEnrollOnBehalfOf(domains, enterpriseCertAuthorities, certTemplates []*graph.Node, cache ADCSCache, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) error {
versionOneTemplates := make([]*graph.Node, 0)
versionTwoTemplates := make([]*graph.Node, 0)
-
for _, node := range certTemplates {
if version, err := node.Properties.Get(ad.SchemaVersion.String()).Float64(); errors.Is(err, graph.ErrPropertyNotFound) {
log.Warnf("Did not get schema version for cert template %d: %v", node.ID, err)
} else if err != nil {
log.Errorf("Error getting schema version for cert template %d: %v", node.ID, err)
+ } else if version == 1 {
+ versionOneTemplates = append(versionOneTemplates, node)
+ } else if version >= 2 {
+ versionTwoTemplates = append(versionTwoTemplates, node)
} else {
- if version == 1 {
- versionOneTemplates = append(versionOneTemplates, node)
- } else if version >= 2 {
- versionTwoTemplates = append(versionTwoTemplates, node)
- } else {
- log.Warnf("Got cert template %d with an invalid version %d", node.ID, version)
- }
+ log.Warnf("Got cert template %d with an invalid version %d", node.ID, version)
}
}
- operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error {
- if results, err := EnrollOnBehalfOfVersionTwo(tx, versionTwoTemplates, certTemplates); err != nil {
- return err
- } else {
- for _, result := range results {
- if !channels.Submit(ctx, outC, result) {
- return nil
- }
- }
+ for _, domain := range domains {
+ innerDomain := domain
- return nil
- }
- })
+ for _, enterpriseCA := range enterpriseCertAuthorities {
+ innerEnterpriseCA := enterpriseCA
- operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error {
- if results, err := EnrollOnBehalfOfVersionOne(tx, versionOneTemplates, certTemplates); err != nil {
- return err
- } else {
- for _, result := range results {
- if !channels.Submit(ctx, outC, result) {
+ if cache.DoesCAChainProperlyToDomain(innerEnterpriseCA, innerDomain) {
+ if publishedCertTemplates := cache.GetPublishedTemplateCache(enterpriseCA.ID); len(publishedCertTemplates) == 0 {
return nil
+ } else {
+ operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error {
+ if results, err := EnrollOnBehalfOfVersionTwo(tx, versionTwoTemplates, publishedCertTemplates, innerDomain); err != nil {
+ return err
+ } else {
+ for _, result := range results {
+ if !channels.Submit(ctx, outC, result) {
+ return nil
+ }
+ }
+
+ return nil
+ }
+ })
+
+ operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error {
+ if results, err := EnrollOnBehalfOfVersionOne(tx, versionOneTemplates, publishedCertTemplates, innerDomain); err != nil {
+ return err
+ } else {
+ for _, result := range results {
+ if !channels.Submit(ctx, outC, result) {
+ return nil
+ }
+ }
+
+ return nil
+ }
+ })
}
}
-
- return nil
}
- })
+ }
return nil
}
-func EnrollOnBehalfOfVersionTwo(tx graph.Transaction, versionTwoCertTemplates, allCertTemplates []*graph.Node) ([]analysis.CreatePostRelationshipJob, error) {
+func EnrollOnBehalfOfVersionTwo(tx graph.Transaction, versionTwoCertTemplates, publishedTemplates []*graph.Node, domainNode *graph.Node) ([]analysis.CreatePostRelationshipJob, error) {
results := make([]analysis.CreatePostRelationshipJob, 0)
- for _, certTemplateOne := range allCertTemplates {
+ for _, certTemplateOne := range publishedTemplates {
if hasBadEku, err := certTemplateHasEku(certTemplateOne, EkuAnyPurpose); errors.Is(err, graph.ErrPropertyNotFound) {
log.Warnf("Did not get EffectiveEKUs for cert template %d: %v", certTemplateOne.ID, err)
} else if err != nil {
@@ -208,12 +219,6 @@ func EnrollOnBehalfOfVersionTwo(tx graph.Transaction, versionTwoCertTemplates, a
log.Errorf("Error getting EffectiveEKUs for cert template %d: %v", certTemplateOne.ID, err)
} else if !hasEku {
continue
- } else if domainNode, err := getDomainForCertTemplate(tx, certTemplateOne); err != nil {
- log.Errorf("Error getting domain node for cert template %d: %v", certTemplateOne.ID, err)
- } else if isLinked, err := DoesCertTemplateLinkToDomain(tx, certTemplateOne, domainNode); err != nil {
- log.Errorf("Error fetching paths from cert template %d to domain: %v", certTemplateOne.ID, err)
- } else if !isLinked {
- continue
} else {
for _, certTemplateTwo := range versionTwoCertTemplates {
if certTemplateOne.ID == certTemplateTwo.ID {
@@ -260,10 +265,10 @@ func certTemplateHasEku(certTemplate *graph.Node, targetEkus ...string) (bool, e
}
}
-func EnrollOnBehalfOfVersionOne(tx graph.Transaction, versionOneCertTemplates []*graph.Node, allCertTemplates []*graph.Node) ([]analysis.CreatePostRelationshipJob, error) {
+func EnrollOnBehalfOfVersionOne(tx graph.Transaction, versionOneCertTemplates []*graph.Node, publishedTemplates []*graph.Node, domainNode *graph.Node) ([]analysis.CreatePostRelationshipJob, error) {
results := make([]analysis.CreatePostRelationshipJob, 0)
- for _, certTemplateOne := range allCertTemplates {
+ for _, certTemplateOne := range publishedTemplates {
//prefilter as much as we can first
if hasEku, err := certTemplateHasEkuOrAll(certTemplateOne, EkuCertRequestAgent, EkuAnyPurpose); errors.Is(err, graph.ErrPropertyNotFound) {
log.Warnf("Error checking ekus for certtemplate %d: %v", certTemplateOne.ID, err)
@@ -271,12 +276,6 @@ func EnrollOnBehalfOfVersionOne(tx graph.Transaction, versionOneCertTemplates []
log.Errorf("Error checking ekus for certtemplate %d: %v", certTemplateOne.ID, err)
} else if !hasEku {
continue
- } else if domainNode, err := getDomainForCertTemplate(tx, certTemplateOne); err != nil {
- log.Errorf("Error getting domain node for certtemplate %d: %v", certTemplateOne.ID, err)
- } else if hasPath, err := DoesCertTemplateLinkToDomain(tx, certTemplateOne, domainNode); err != nil {
- log.Errorf("Error fetching paths from certtemplate %d to domain: %v", certTemplateOne.ID, err)
- } else if !hasPath {
- continue
} else {
for _, certTemplateTwo := range versionOneCertTemplates {
if hasPath, err := DoesCertTemplateLinkToDomain(tx, certTemplateTwo, domainNode); err != nil {
@@ -359,19 +358,9 @@ func certTemplateHasEkuOrAll(certTemplate *graph.Node, targetEkus ...string) (bo
}
}
-func getDomainForCertTemplate(tx graph.Transaction, certTemplate *graph.Node) (*graph.Node, error) {
- if domainSid, err := certTemplate.Properties.Get(ad.DomainSID.String()).String(); err != nil {
- return &graph.Node{}, err
- } else if domainNode, err := analysis.FetchNodeByObjectID(tx, domainSid); err != nil {
- return &graph.Node{}, err
- } else {
- return domainNode, nil
- }
-}
-
func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *graph.Relationship) (graph.PathSet, error) {
/*
- MATCH p1 = (x)-[:MemberOf*0..]->()-[:GenericAll|Enroll|AllExtendedRights]->(ct1:CertTemplate)-[:PublishedTo]->(eca1:EnterpriseCA)
+ MATCH p1 = (x)-[:MemberOf*0..]->()-[:GenericAll|Enroll|AllExtendedRights]->(ct1:CertTemplate)-[:PublishedTo]->(eca1:EnterpriseCA)-[:TrustedForNTAuth]->(:NTAuthStore)-[:NTAuthStoreFor]->(d)
WHERE x.objectid = "S-1-5-21-83094068-830424655-2031507174-500"
AND d.objectid = "S-1-5-21-83094068-830424655-2031507174"
AND ct1.requiresmanagerapproval = false
@@ -483,7 +472,7 @@ func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *gr
for _, n := range startNodes.Slice() {
if err := traversalInst.BreadthFirst(ctx, traversal.Plan{
Root: n,
- Driver: ADCSESC3Path1Pattern(enterpriseCANodes).Do(func(terminal *graph.PathSegment) error {
+ Driver: ADCSESC3Path1Pattern(edge.EndID, enterpriseCANodes).Do(func(terminal *graph.PathSegment) error {
certTemplateNode := terminal.Search(func(nextSegment *graph.PathSegment) bool {
return nextSegment.Node.Kinds.ContainsOneOf(ad.CertTemplate)
})
@@ -673,7 +662,7 @@ func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *gr
return paths, nil
}
-func ADCSESC3Path1Pattern(enterpriseCAs cardinality.Duplex[uint64]) traversal.PatternContinuation {
+func ADCSESC3Path1Pattern(domainId graph.ID, enterpriseCAs cardinality.Duplex[uint64]) traversal.PatternContinuation {
return traversal.NewPattern().OutboundWithDepth(0, 0, query.And(
query.Kind(query.Relationship(), ad.MemberOf),
query.Kind(query.End(), ad.Group),
@@ -696,6 +685,14 @@ func ADCSESC3Path1Pattern(enterpriseCAs cardinality.Duplex[uint64]) traversal.Pa
query.KindIn(query.Relationship(), ad.PublishedTo),
query.InIDs(query.End(), graph.DuplexToGraphIDs(enterpriseCAs)...),
query.Kind(query.End(), ad.EnterpriseCA),
+ )).
+ Outbound(query.And(
+ query.KindIn(query.Relationship(), ad.TrustedForNTAuth),
+ query.Kind(query.End(), ad.NTAuthStore),
+ )).
+ Outbound(query.And(
+ query.KindIn(query.Relationship(), ad.NTAuthStoreFor),
+ query.Equals(query.EndID(), domainId),
))
}
diff --git a/packages/go/cypher/models/cypher/functions.go b/packages/go/cypher/models/cypher/functions.go
index 1ffed87f99..c6dce37f2b 100644
--- a/packages/go/cypher/models/cypher/functions.go
+++ b/packages/go/cypher/models/cypher/functions.go
@@ -30,4 +30,21 @@ const (
NodeLabelsFunction = "labels"
EdgeTypeFunction = "type"
StringSplitToArrayFunction = "split"
+ ToStringFunction = "tostring"
+ ToIntegerFunction = "toint"
+ ListSizeFunction = "size"
+
+ // ITTC - Instant Type; Temporal Component (https://neo4j.com/docs/cypher-manual/current/functions/temporal/)
+ ITTCYear = "year"
+ ITTCMonth = "month"
+ ITTCDay = "day"
+ ITTCHour = "hour"
+ ITTCMinute = "minute"
+ ITTCSecond = "second"
+ ITTCMillisecond = "millisecond"
+ ITTCMicrosecond = "microsecond"
+ ITTCNanosecond = "nanosecond"
+ ITTCTimeZone = "timezone"
+ ITTCEpochSeconds = "epochseconds"
+ ITTCEpochMilliseconds = "epochmillis"
)
diff --git a/packages/go/cypher/models/pgsql/format/format.go b/packages/go/cypher/models/pgsql/format/format.go
index ac2caeaf21..4280aef044 100644
--- a/packages/go/cypher/models/pgsql/format/format.go
+++ b/packages/go/cypher/models/pgsql/format/format.go
@@ -137,6 +137,12 @@ func formatValue(builder *OutputBuilder, value any) error {
case bool:
builder.Write(strconv.FormatBool(typedValue))
+ case float32:
+ builder.Write(strconv.FormatFloat(float64(typedValue), 'f', -1, 64))
+
+ case float64:
+ builder.Write(strconv.FormatFloat(typedValue, 'f', -1, 64))
+
default:
return fmt.Errorf("unsupported literal type: %T", value)
}
@@ -482,8 +488,27 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
exprStack = append(exprStack, pgsql.FormattingLiteral("not "))
}
+ case pgsql.ProjectionFrom:
+ for idx, projection := range typedNextExpr.Projection {
+ if idx > 0 {
+ builder.Write(", ")
+ }
+
+ if err := formatNode(builder, projection); err != nil {
+ return err
+ }
+ }
+
+ if len(typedNextExpr.From) > 0 {
+ builder.Write(" from ")
+
+ if err := formatFromClauses(builder, typedNextExpr.From); err != nil {
+ return err
+ }
+ }
+
default:
- return fmt.Errorf("unsupported node type: %T", nextExpr)
+ return fmt.Errorf("unable to format pgsql node type: %T", nextExpr)
}
}
diff --git a/packages/go/cypher/models/pgsql/functions.go b/packages/go/cypher/models/pgsql/functions.go
index 92bb29f247..be5267764a 100644
--- a/packages/go/cypher/models/pgsql/functions.go
+++ b/packages/go/cypher/models/pgsql/functions.go
@@ -24,6 +24,7 @@ const (
FunctionJSONBToTextArray Identifier = "jsonb_to_text_array"
FunctionJSONBArrayElementsText Identifier = "jsonb_array_elements_text"
FunctionJSONBBuildObject Identifier = "jsonb_build_object"
+ FunctionJSONBArrayLength Identifier = "jsonb_array_length"
FunctionArrayLength Identifier = "array_length"
FunctionArrayAggregate Identifier = "array_agg"
FunctionMin Identifier = "min"
@@ -41,4 +42,5 @@ const (
FunctionCount Identifier = "count"
FunctionStringToArray Identifier = "string_to_array"
FunctionEdgesToPath Identifier = "edges_to_path"
+ FunctionExtract Identifier = "extract"
)
diff --git a/packages/go/cypher/models/pgsql/identifiers.go b/packages/go/cypher/models/pgsql/identifiers.go
index 4e19ab092a..582290bd93 100644
--- a/packages/go/cypher/models/pgsql/identifiers.go
+++ b/packages/go/cypher/models/pgsql/identifiers.go
@@ -25,8 +25,23 @@ import (
const (
WildcardIdentifier Identifier = "*"
+ EpochIdentifier Identifier = "epoch"
)
+var reservedIdentifiers = []Identifier{
+ EpochIdentifier,
+}
+
+func IsReservedIdentifier(identifier Identifier) bool {
+ for _, reservedIdentifier := range reservedIdentifiers {
+ if identifier == reservedIdentifier {
+ return true
+ }
+ }
+
+ return false
+}
+
func AsOptionalIdentifier(identifier Identifier) models.Optional[Identifier] {
return models.ValueOptional(identifier)
}
diff --git a/packages/go/cypher/models/pgsql/model.go b/packages/go/cypher/models/pgsql/model.go
index acc97590b4..c890a855e0 100644
--- a/packages/go/cypher/models/pgsql/model.go
+++ b/packages/go/cypher/models/pgsql/model.go
@@ -403,9 +403,17 @@ type AnyExpression struct {
}
func NewAnyExpression(inner Expression) AnyExpression {
- return AnyExpression{
+ newAnyExpression := AnyExpression{
Expression: inner,
}
+
+ // This is a guard to prevent recursive wrapping of an expression in an Any expression
+ switch innerTypeHint := inner.(type) {
+ case TypeHinted:
+ newAnyExpression.CastType = innerTypeHint.TypeHint()
+ }
+
+ return newAnyExpression
}
func (s AnyExpression) AsExpression() Expression {
@@ -972,6 +980,19 @@ func (s Projection) NodeType() string {
return "projection"
}
+type ProjectionFrom struct {
+ Projection Projection
+ From []FromClause
+}
+
+func (s ProjectionFrom) NodeType() string {
+ return "projection from"
+}
+
+func (s ProjectionFrom) AsExpression() Expression {
+ return s
+}
+
// Select is a SQL expression that is evaluated to fetch data.
type Select struct {
Distinct bool
diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go
index de4a12ba6f..8839e37971 100644
--- a/packages/go/cypher/models/pgsql/pgtypes.go
+++ b/packages/go/cypher/models/pgsql/pgtypes.go
@@ -93,6 +93,8 @@ const (
TextArray DataType = "text[]"
JSONB DataType = "jsonb"
JSONBArray DataType = "jsonb[]"
+ Numeric DataType = "numeric"
+ NumericArray DataType = "numeric[]"
Date DataType = "date"
TimeWithTimeZone DataType = "time with time zone"
TimeWithoutTimeZone DataType = "time without time zone"
@@ -109,7 +111,8 @@ const (
ExpansionTerminalNode DataType = "expansion_terminal_node"
)
-func (s DataType) Convert(other DataType) (DataType, bool) {
+// TODO: operator, while unused, is part of a refactor for this function to make it operator aware
+func (s DataType) Compatible(other DataType, operator Operator) (DataType, bool) {
if s == other {
return s, true
}
@@ -132,6 +135,12 @@ func (s DataType) Convert(other DataType) (DataType, bool) {
case Float8:
return Float8, true
+ case Float4Array:
+ return Float4, true
+
+ case Float8Array:
+ return Float8, true
+
case Text:
return Text, true
}
@@ -141,6 +150,21 @@ func (s DataType) Convert(other DataType) (DataType, bool) {
case Float4:
return Float8, true
+ case Float4Array, Float8Array:
+ return Float8, true
+
+ case Text:
+ return Text, true
+ }
+
+ case Numeric:
+ switch other {
+ case Float4, Float8, Int2, Int4, Int8:
+ return Numeric, true
+
+ case Float4Array, Float8Array, NumericArray:
+ return Numeric, true
+
case Text:
return Text, true
}
@@ -156,6 +180,15 @@ func (s DataType) Convert(other DataType) (DataType, bool) {
case Int8:
return Int8, true
+ case Int2Array:
+ return Int2, true
+
+ case Int4Array:
+ return Int4, true
+
+ case Int8Array:
+ return Int8, true
+
case Text:
return Text, true
}
@@ -168,6 +201,12 @@ func (s DataType) Convert(other DataType) (DataType, bool) {
case Int8:
return Int8, true
+ case Int2Array, Int4Array:
+ return Int4, true
+
+ case Int8Array:
+ return Int8, true
+
case Text:
return Text, true
}
@@ -177,9 +216,42 @@ func (s DataType) Convert(other DataType) (DataType, bool) {
case Int2, Int4, Int8:
return Int8, true
+ case Int2Array, Int4Array, Int8Array:
+ return Int8, true
+
+ case Text:
+ return Text, true
+ }
+
+ case Int:
+ switch other {
+ case Int2, Int4, Int:
+ return Int, true
+
+ case Int8:
+ return Int8, true
+
case Text:
return Text, true
}
+
+ case Int2Array:
+ switch other {
+ case Int2Array, Int4Array, Int8Array:
+ return other, true
+ }
+
+ case Int4Array:
+ switch other {
+ case Int4Array, Int8Array:
+ return other, true
+ }
+
+ case Float4Array:
+ switch other {
+ case Float4Array, Float8Array:
+ return other, true
+ }
}
return UnsetDataType, false
@@ -207,7 +279,7 @@ func (s DataType) MatchesOneOf(others ...DataType) bool {
func (s DataType) IsArrayType() bool {
switch s {
- case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray:
+ case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray, JSONBArray, NodeCompositeArray, EdgeCompositeArray, NumericArray:
return true
}
@@ -239,6 +311,8 @@ func (s DataType) ToArrayType() (DataType, error) {
return Float8Array, nil
case Text, TextArray:
return TextArray, nil
+ case Numeric, NumericArray:
+ return NumericArray, nil
default:
return UnknownDataType, ErrNoAvailableArrayDataType
}
@@ -258,8 +332,10 @@ func (s DataType) ArrayBaseType() (DataType, error) {
return Float8, nil
case TextArray:
return Text, nil
+ case NumericArray:
+ return Numeric, nil
default:
- return UnknownDataType, ErrNonArrayDataType
+ return s, nil
}
}
diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql
index 8c18b2ed0b..69b09a8b98 100644
--- a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql
+++ b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql
@@ -82,7 +82,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from
s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1
from s0,
node n1
- where (s0.n0).id = any (jsonb_to_text_array(n1.properties -> 'captured_ids')::int4[]))
+ where (s0.n0).id = any (jsonb_to_text_array(n1.properties -> 'captured_ids')::int8[]))
select s1.n0 as s, s1.n1 as e
from s1;
@@ -498,3 +498,66 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
where n0.properties ->> 'system_tags' like '%' || ('text')::text)
select s0.n0 as n
from s0;
+
+-- case: match (n:NodeKind1) where toString(n.functionallevel) in ['2008 R2','2012','2008','2003','2003 Interim','2000 Mixed/Native'] return n
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]
+ and (n0.properties -> 'functionallevel')::text = any
+ (array ['2008 R2', '2012', '2008', '2003', '2003 Interim', '2000 Mixed/Native']::text[]))
+select s0.n0 as n
+from s0;
+
+-- case: match (n:NodeKind1) where toInt(n.value) in [1, 2, 3, 4] return n
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]
+ and (n0.properties -> 'value')::int8 = any (array [1, 2, 3, 4]::int8[]))
+select s0.n0 as n
+from s0;
+
+-- case: match (u:NodeKind1) where u.pwdlastset < (datetime().epochseconds - (365 * 86400)) and not u.pwdlastset IN [-1.0, 0.0] return u limit 100
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]
+ and (n0.properties -> 'pwdlastset')::numeric <
+ (extract(epoch from now()::timestamp with time zone)::numeric - (365 * 86400))
+ and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[]))
+select s0.n0 as u
+from s0
+limit 100;
+
+-- case: match (u:NodeKind1) where u.pwdlastset < (datetime().epochmillis - (365 * 86400000)) and not u.pwdlastset IN [-1.0, 0.0] return u limit 100
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]
+ and (n0.properties -> 'pwdlastset')::numeric <
+ (extract(epoch from now()::timestamp with time zone)::numeric * 1000 - (365 * 86400000))
+ and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[]))
+select s0.n0 as u
+from s0
+limit 100;
+
+-- case: match (n:NodeKind1) where size(n.array_value) > 0 return n
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[]
+ and jsonb_array_length(n0.properties -> 'array_value')::int > 0)
+select s0.n0 as n
+from s0;
+
+-- case: match (n) where 1 in n.array return n
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where 1 = any (jsonb_to_text_array(n0.properties -> 'array')::int8[]))
+select s0.n0 as n
+from s0;
+
+-- case: match (n) where $p in n.array or $f in n.array return n
+-- cypher_params: {"p": 1, "f": "text"}
+with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0
+ from node n0
+ where @pi0::float8 = any (jsonb_to_text_array(n0.properties -> 'array')::float8[])
+ or @pi1::text = any (jsonb_to_text_array(n0.properties -> 'array')::text[]))
+select s0.n0 as n
+from s0;
diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql
index 2c85c05023..4cb3465067 100644
--- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql
+++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql
@@ -21,7 +21,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite
from edge e0
join node n0 on n0.id = e0.start_id
join node n1 on n1.id = e0.end_id)
-select edges_to_path(variadic array [(s0.e0).id]::int4[])::pathcomposite as p
+select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p
from s0;
-- case: match p = ()-[r1]->()-[r2]->(e) return e
@@ -92,7 +92,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite
edge e1
join node n2 on (n2.properties -> 'is_target')::bool and n2.id = e1.start_id
where (s0.n1).id = e1.end_id)
-select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int4[])::pathcomposite as p
+select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int8[])::pathcomposite as p
from s1;
-- case: match p = ()-[*..]->() return p limit 1
@@ -126,7 +126,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id)
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0
limit 1;
@@ -162,7 +162,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied),
s1 as (select s0.e0 as e0,
s0.ep0 as ep0,
@@ -174,7 +174,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
edge e1
join node n2 on n2.id = e1.end_id
where (s0.n1).id = e1.start_id)
-select edges_to_path(variadic array [(s1.e1).id]::int4[] || s1.ep0)::pathcomposite as p
+select edges_to_path(variadic array [(s1.e1).id]::int8[] || s1.ep0)::pathcomposite as p
from s1
limit 1;
@@ -220,8 +220,8 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite
ex0
join edge e1 on e1.id = any (ex0.path)
join node n1 on n1.id = ex0.root_id
- join node n2 on e1.id = ex0.path[array_length(ex0.path, 1)::int4] and n2.id = e1.end_id)
-select s1.e0 as e, edges_to_path(variadic array [(s1.e0).id]::int4[] || s1.ep0)::pathcomposite as p
+ join node n2 on e1.id = ex0.path[array_length(ex0.path, 1)::int] and n2.id = e1.end_id)
+select s1.e0 as e, edges_to_path(variadic array [(s1.e0).id]::int8[] || s1.ep0)::pathcomposite as p
from s1;
-- case: match p = (m:NodeKind1)-[:EdgeKind1]->(c:NodeKind2) where m.objectid ends with "-513" and not toUpper(c.operatingsystem) contains "SERVER" return p limit 1000
@@ -235,6 +235,6 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite
not upper(n1.properties ->> 'operatingsystem')::text like '%SERVER%' and
n1.id = e0.end_id
where e0.kind_id = any (array [11]::int2[]))
-select edges_to_path(variadic array [(s0.e0).id]::int4[])::pathcomposite as p
+select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p
from s0
limit 1000;
diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql
index e49a4e0e84..87da96d391 100644
--- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql
+++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql
@@ -45,7 +45,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id)
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id)
select s0.n0 as n, s0.n1 as e
from s0;
@@ -80,7 +80,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select s0.n1 as e
from s0;
@@ -119,7 +119,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select s0.n1 as e
from s0;
@@ -155,7 +155,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select s0.n0 as n
from s0;
@@ -191,7 +191,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied),
s1 as (select s0.e0 as e0,
s0.ep0 as ep0,
@@ -237,7 +237,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied),
s1 as (select s0.e0 as e0,
s0.ep0 as ep0,
@@ -286,7 +286,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
ex1
join edge e2 on e2.id = any (ex1.path)
join node n2 on n2.id = ex1.root_id
- join node n3 on e2.id = ex1.path[array_length(ex1.path, 1)::int4] and n3.id = e2.end_id)
+ join node n3 on e2.id = ex1.path[array_length(ex1.path, 1)::int] and n3.id = e2.end_id)
select s2.n3 as l
from s2;
@@ -332,7 +332,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0
@@ -372,7 +372,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied
and n0.id <> n1.id)
select edges_to_path(variadic ep0)::pathcomposite as p
@@ -422,7 +422,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0;
@@ -463,7 +463,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id)
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.start_id)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0
limit 10;
@@ -505,7 +505,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.start_id
where ex0.satisfied),
s1 as (with recursive ex1(root_id, next_id, depth, satisfied, is_cycle, path) as (select e1.start_id,
e1.end_id,
@@ -544,7 +544,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
ex1
join edge e1 on e1.id = any (ex1.path)
join node n1 on n1.id = ex1.root_id
- join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int4] and n2.id = e1.start_id)
+ join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int] and n2.id = e1.start_id)
select edges_to_path(variadic s1.ep1 || s1.ep0)::pathcomposite as p
from s1
limit 10;
@@ -593,7 +593,7 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0
diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql
index 5544c30586..b7035887f1 100644
--- a/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql
+++ b/packages/go/cypher/models/pgsql/test/translation_cases/shortest_paths.sql
@@ -28,7 +28,7 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path)
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id)
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0;
@@ -46,7 +46,28 @@ with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path)
from ex0
join edge e0 on e0.id = any (ex0.path)
join node n0 on n0.id = ex0.root_id
- join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
where ex0.satisfied)
select edges_to_path(variadic ep0)::pathcomposite as p
from s0;
+
+-- case: match p=shortestPath((n:NodeKind1)-[:EdgeKind1*1..]->(m)) where 'admin_tier_0' in split(m.system_tags, ' ') and n.objectid ends with '-513' and n<>m return p limit 1000
+-- cypher_params: {}
+-- pgsql_params: {"pi0":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select e0.start_id, e0.end_id, 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.start_id = e0.end_id, array [e0.id] from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.properties ->> 'objectid' like '%-513' and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [11]::int2[]);", "pi1":"insert into next_pathspace (root_id, next_id, depth, satisfied, is_cycle, path) select ex0.root_id, e0.end_id, ex0.depth + 1, 'admin_tier_0' = any (string_to_array(n1.properties ->> 'system_tags', ' ')::text[]), e0.id = any (ex0.path), ex0.path || e0.id from pathspace ex0 join edge e0 on e0.start_id = ex0.next_id join node n1 on n1.id = e0.end_id where ex0.depth < 5 and not ex0.is_cycle and e0.kind_id = any (array [11]::int2[]);"}
+with s0 as (with ex0(root_id, next_id, depth, satisfied, is_cycle, path)
+ as (select * from asp_harness(@pi0::text, @pi1::text, 5))
+ select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0,
+ (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite)
+ from edge e0
+ where e0.id = any (ex0.path)) as e0,
+ ex0.path as ep0,
+ (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1
+ from ex0
+ join edge e0 on e0.id = any (ex0.path)
+ join node n0 on n0.id = ex0.root_id
+ join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int] and n1.id = e0.end_id
+ where ex0.satisfied
+ and n0.id <> n1.id)
+select edges_to_path(variadic ep0)::pathcomposite as p
+from s0
+limit 1000;
diff --git a/packages/go/cypher/models/pgsql/translate/expansion.go b/packages/go/cypher/models/pgsql/translate/expansion.go
index 872b66c2bb..e819d7cb07 100644
--- a/packages/go/cypher/models/pgsql/translate/expansion.go
+++ b/packages/go/cypher/models/pgsql/translate/expansion.go
@@ -320,7 +320,7 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave
pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath},
pgsql.NewLiteral(1, pgsql.Int8),
},
- CastType: pgsql.Int4,
+ CastType: pgsql.Int,
},
},
},
@@ -417,8 +417,15 @@ func (s *Translator) buildAllShortestPathsExpansionRoot(part *PatternPart, trave
),
}
- // Make sure to only accept paths that are satisfied
- expansion.ProjectionStatement.Where = pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionSatisfied}
+ // Constraints that target the terminal node may crop up here where it's finally in scope. Additionally,
+ // only accept paths that are marked satisfied from the recursive descent CTE
+ if constraints, err := consumeConstraintsFrom(traversalStep.Expansion.Value.Frame.Visible, s.treeTranslator.IdentifierConstraints); err != nil {
+ return pgsql.Query{}, err
+ } else if projectionConstraints, err := ConjoinExpressions([]pgsql.Expression{pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionSatisfied}, constraints.Expression}); err != nil {
+ return pgsql.Query{}, err
+ } else {
+ expansion.ProjectionStatement.Where = projectionConstraints
+ }
}
} else {
expansion.PrimerStatement.Projection = []pgsql.SelectItem{
@@ -653,7 +660,7 @@ func (s *Translator) buildExpansionPatternRoot(part *PatternPart, traversalStep
pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath},
pgsql.NewLiteral(1, pgsql.Int8),
},
- CastType: pgsql.Int4,
+ CastType: pgsql.Int,
},
},
},
@@ -940,7 +947,7 @@ func (s *Translator) buildExpansionPatternStep(part *PatternPart, traversalStep
pgsql.CompoundIdentifier{traversalStep.Expansion.Value.Binding.Identifier, expansionPath},
pgsql.NewLiteral(1, pgsql.Int8),
},
- CastType: pgsql.Int4,
+ CastType: pgsql.Int,
},
},
},
diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go
index 1da57810e6..1fdd24a364 100644
--- a/packages/go/cypher/models/pgsql/translate/expression.go
+++ b/packages/go/cypher/models/pgsql/translate/expression.go
@@ -31,8 +31,15 @@ type PropertyLookup struct {
}
func asPropertyLookup(expression pgsql.Expression) (*pgsql.BinaryExpression, bool) {
- if binaryExpression, isBinaryExpression := expression.(*pgsql.BinaryExpression); isBinaryExpression {
- return binaryExpression, pgsql.OperatorIsPropertyLookup(binaryExpression.Operator)
+ switch typedExpression := expression.(type) {
+ case pgsql.AnyExpression:
+ // This is here to unwrap Any expressions that have been passed in as a property lookup. This is
+ // common when dealing with array operators. In the future this check should be handled by the
+ // caller to simplify the logic here.
+ return asPropertyLookup(typedExpression.Expression)
+
+ case *pgsql.BinaryExpression:
+ return typedExpression, pgsql.OperatorIsPropertyLookup(typedExpression.Operator)
}
return nil, false
@@ -64,10 +71,17 @@ func ExtractSyntaxNodeReferences(root pgsql.SyntaxNode) (*pgsql.IdentifierSet, e
func(node pgsql.SyntaxNode, errorHandler walk.CancelableErrorHandler) {
switch typedNode := node.(type) {
case pgsql.Identifier:
- dependencies.Add(typedNode)
+ // Filter for reserved identifiers
+ if !pgsql.IsReservedIdentifier(typedNode) {
+ dependencies.Add(typedNode)
+ }
case pgsql.CompoundIdentifier:
- dependencies.Add(typedNode.Root())
+ identifier := typedNode.Root()
+
+ if !pgsql.IsReservedIdentifier(identifier) {
+ dependencies.Add(identifier)
+ }
}
},
))
@@ -83,7 +97,6 @@ func applyUnaryExpressionTypeHints(expression *pgsql.UnaryExpression) error {
func rewritePropertyLookupOperator(propertyLookup *pgsql.BinaryExpression, dataType pgsql.DataType) pgsql.Expression {
if dataType.IsArrayType() {
- // This property lookup needs to be coerced into an array type using a function
return pgsql.FunctionCall{
Function: pgsql.FunctionJSONBToTextArray,
Parameters: []pgsql.Expression{propertyLookup},
@@ -126,7 +139,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy
if isLeftHinted {
if isRightHinted {
- if higherLevelHint, matchesOrConverts := leftHint.Convert(rightHint); !matchesOrConverts {
+ if higherLevelHint, matchesOrConverts := leftHint.Compatible(rightHint, expression.Operator); !matchesOrConverts {
return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, rightHint)
} else {
return higherLevelHint, nil
@@ -136,44 +149,52 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy
} else if inferredRightHint == pgsql.UnknownDataType {
// Assume the right side is convertable and return the left operand hint
return leftHint, nil
- } else if upcastHint, matchesOrConverts := leftHint.Convert(inferredRightHint); !matchesOrConverts {
+ } else if upcastHint, matchesOrConverts := leftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts {
return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredRightHint)
} else {
return upcastHint, nil
}
} else if isRightHinted {
// There's no left type, attempt to infer it
- if inferredLeftHint, err := InferExpressionType(expression.ROperand); err != nil {
+ if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil {
return pgsql.UnsetDataType, err
} else if inferredLeftHint == pgsql.UnknownDataType {
// Assume the right side is convertable and return the left operand hint
return rightHint, nil
- } else if upcastHint, matchesOrConverts := rightHint.Convert(inferredLeftHint); !matchesOrConverts {
- return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredLeftHint)
+ } else if upcastHint, matchesOrConverts := rightHint.Compatible(inferredLeftHint, expression.Operator); !matchesOrConverts {
+ return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, rightHint, inferredLeftHint)
} else {
return upcastHint, nil
}
- } else if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil {
- return pgsql.UnsetDataType, err
- } else if inferredRightHint, err := InferExpressionType(expression.ROperand); err != nil {
- return pgsql.UnsetDataType, err
- } else if inferredLeftHint == pgsql.UnknownDataType && inferredRightHint == pgsql.UnknownDataType {
- // If neither side has type information then check the operator to see if it implies some type hinting
+ } else {
+ // If neither side has specific type information then check the operator to see if it implies some type
+ // hinting before resorting to inference
switch expression.Operator {
case pgsql.OperatorStartsWith, pgsql.OperatorContains, pgsql.OperatorEndsWith:
// String operations imply the operands must be text
return pgsql.Text, nil
- // TODO: Boolean inference for OperatorAnd and OperatorOr may want to be plumbed here
+ case pgsql.OperatorAnd, pgsql.OperatorOr:
+ // Boolean operators imply that the operands must be boolean
+ return pgsql.Boolean, nil
default:
- // Unable to infer any type information
- return pgsql.UnknownDataType, nil
+ // The operator does not imply specific type information onto the operands. Attempt to infer any
+ // information as a last ditch effort to type the AST nodes
+ if inferredLeftHint, err := InferExpressionType(expression.LOperand); err != nil {
+ return pgsql.UnsetDataType, err
+ } else if inferredRightHint, err := InferExpressionType(expression.ROperand); err != nil {
+ return pgsql.UnsetDataType, err
+ } else if inferredLeftHint == pgsql.UnknownDataType && inferredRightHint == pgsql.UnknownDataType {
+ // Unable to infer any type information, this may be resolved elsewhere so this is not explicitly
+ // an error condition
+ return pgsql.UnknownDataType, nil
+ } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts {
+ return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, inferredLeftHint, inferredRightHint)
+ } else {
+ return higherLevelHint, nil
+ }
}
- } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Convert(inferredRightHint); !matchesOrConverts {
- return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredLeftHint)
- } else {
- return higherLevelHint, nil
}
}
@@ -191,7 +212,7 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) {
// Infer type information for well known column names
switch typedExpression[1] {
case pgsql.ColumnGraphID, pgsql.ColumnID, pgsql.ColumnStartID, pgsql.ColumnEndID:
- return pgsql.Int4, nil
+ return pgsql.Int8, nil
case pgsql.ColumnKindID:
return pgsql.Int2, nil
@@ -215,13 +236,18 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) {
// This is unknown, not unset meaning that it can be re-cast by future inference inspections
return pgsql.UnknownDataType, nil
- case pgsql.OperatorAnd, pgsql.OperatorOr:
+ case pgsql.OperatorAnd, pgsql.OperatorOr, pgsql.OperatorEquals, pgsql.OperatorGreaterThan, pgsql.OperatorGreaterThanOrEqualTo,
+ pgsql.OperatorLessThan, pgsql.OperatorLessThanOrEqualTo, pgsql.OperatorIn, pgsql.OperatorJSONBFieldExists,
+ pgsql.OperatorLike, pgsql.OperatorILike, pgsql.OperatorPGArrayOverlap:
return pgsql.Boolean, nil
default:
return inferBinaryExpressionType(typedExpression)
}
+ case pgsql.Parenthetical:
+ return InferExpressionType(typedExpression.Expression)
+
default:
log.Infof("unable to infer type hint for expression type: %T", expression)
return pgsql.UnknownDataType, nil
@@ -263,31 +289,65 @@ func TypeCastExpression(expression pgsql.Expression, dataType pgsql.DataType) (p
return pgsql.NewTypeCast(expression, dataType), nil
}
-func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression, expressionTypeHint pgsql.DataType) error {
- if leftPropertyLookup, isPropertyLookup := asPropertyLookup(expression.LOperand); isPropertyLookup {
- if lookupRequiresElementType(expressionTypeHint, expression.Operator, expression.ROperand) {
- // Take the base type of the array type hint: in
- if arrayBaseType, err := expressionTypeHint.ArrayBaseType(); err != nil {
- return err
- } else {
- expressionTypeHint = arrayBaseType
- }
- }
+func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error {
+ var (
+ leftPropertyLookup, hasLeftPropertyLookup = asPropertyLookup(expression.LOperand)
+ rightPropertyLookup, hasRightPropertyLookup = asPropertyLookup(expression.ROperand)
+ )
- expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, expressionTypeHint)
+ // Don't rewrite direct property comparisons
+ if hasLeftPropertyLookup && hasRightPropertyLookup {
+ return nil
}
- if rightPropertyLookup, isPropertyLookup := asPropertyLookup(expression.ROperand); isPropertyLookup {
- if lookupRequiresElementType(expressionTypeHint, expression.Operator, expression.LOperand) {
- // Take the base type of the array type hint: in
- if arrayBaseType, err := expressionTypeHint.ArrayBaseType(); err != nil {
+ if hasLeftPropertyLookup {
+ // This check exists here to prevent overwriting a property lookup that's part of an `in`
+ // binary expression. This may warrant better ergonomics in the future.
+ if anyExpression, isAnyExpression := expression.ROperand.(pgsql.AnyExpression); isAnyExpression {
+ if arrayBaseType, err := anyExpression.CastType.ArrayBaseType(); err != nil {
return err
} else {
- expressionTypeHint = arrayBaseType
+ expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType)
+ }
+ } else if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil {
+ return err
+ } else {
+ switch expression.Operator {
+ case pgsql.OperatorIn:
+ if arrayBaseType, err := rOperandTypeHint.ArrayBaseType(); err != nil {
+ return err
+ } else {
+ expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType)
+ }
+
+ case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch:
+ expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, pgsql.Text)
+
+ default:
+ expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, rOperandTypeHint)
}
}
+ }
+
+ if hasRightPropertyLookup {
+ if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil {
+ return err
+ } else {
+ switch expression.Operator {
+ case pgsql.OperatorIn:
+ if arrayType, err := lOperandTypeHint.ToArrayType(); err != nil {
+ return err
+ } else {
+ expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, arrayType)
+ }
+
+ case pgsql.OperatorStartsWith, pgsql.OperatorEndsWith, pgsql.OperatorContains, pgsql.OperatorRegexMatch:
+ expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, pgsql.Text)
- expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, expressionTypeHint)
+ default:
+ expression.ROperand = rewritePropertyLookupOperator(rightPropertyLookup, lOperandTypeHint)
+ }
+ }
}
return nil
@@ -301,11 +361,7 @@ func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error {
return nil
}
- if expressionTypeHint, err := InferExpressionType(expression); err != nil {
- return err
- } else {
- return rewritePropertyLookupOperands(expression, expressionTypeHint)
- }
+ return rewritePropertyLookupOperands(expression)
}
type Builder struct {
diff --git a/packages/go/cypher/models/pgsql/translate/expression_test.go b/packages/go/cypher/models/pgsql/translate/expression_test.go
index 802c9eab75..1e514ce38e 100644
--- a/packages/go/cypher/models/pgsql/translate/expression_test.go
+++ b/packages/go/cypher/models/pgsql/translate/expression_test.go
@@ -72,7 +72,7 @@ func TestInferExpressionType(t *testing.T) {
),
),
}, {
- ExpectedType: pgsql.Text,
+ ExpectedType: pgsql.Boolean,
Expression: pgsql.NewBinaryExpression(
mustAsLiteral("123"),
pgsql.OperatorIn,
diff --git a/packages/go/cypher/models/pgsql/translate/projection.go b/packages/go/cypher/models/pgsql/translate/projection.go
index 6bf21f4920..5b3259dfbd 100644
--- a/packages/go/cypher/models/pgsql/translate/projection.go
+++ b/packages/go/cypher/models/pgsql/translate/projection.go
@@ -197,7 +197,7 @@ func buildProjection(alias pgsql.Identifier, projected *BoundIdentifier, scope *
pgsql.OperatorConcatenate,
pgsql.ArrayLiteral{
Values: edgeReferences,
- CastType: pgsql.Int4Array,
+ CastType: pgsql.Int8Array,
},
)
}
diff --git a/packages/go/cypher/models/pgsql/translate/translation.go b/packages/go/cypher/models/pgsql/translate/translation.go
index 03eb58cd20..5bc71308a9 100644
--- a/packages/go/cypher/models/pgsql/translate/translation.go
+++ b/packages/go/cypher/models/pgsql/translate/translation.go
@@ -60,6 +60,76 @@ func (s *Translator) translateRemoveItem(removeItem *cypher.RemoveItem) error {
return nil
}
+func (s *Translator) translatePropertyLookup(lookup *cypher.PropertyLookup) {
+ if translatedAtom, err := s.treeTranslator.Pop(); err != nil {
+ s.SetError(err)
+ } else {
+ switch typedTranslatedAtom := translatedAtom.(type) {
+ case pgsql.Identifier:
+ if fieldIdentifierLiteral, err := pgsql.AsLiteral(lookup.Symbols[0]); err != nil {
+ s.SetError(err)
+ } else {
+ s.treeTranslator.Push(pgsql.CompoundIdentifier{typedTranslatedAtom, pgsql.ColumnProperties})
+ s.treeTranslator.Push(fieldIdentifierLiteral)
+
+ if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorPropertyLookup); err != nil {
+ s.SetError(err)
+ }
+ }
+
+ case pgsql.FunctionCall:
+ if fieldIdentifierLiteral, err := pgsql.AsLiteral(lookup.Symbols[0]); err != nil {
+ s.SetError(err)
+ } else if componentName, typeOK := fieldIdentifierLiteral.Value.(string); !typeOK {
+ s.SetErrorf("expected a string component name in translated literal but received type: %T", fieldIdentifierLiteral.Value)
+ } else {
+ switch typedTranslatedAtom.Function {
+ case pgsql.FunctionCurrentDate, pgsql.FunctionLocalTime, pgsql.FunctionCurrentTime, pgsql.FunctionLocalTimestamp, pgsql.FunctionNow:
+ switch componentName {
+ case cypher.ITTCEpochSeconds:
+ s.treeTranslator.Push(pgsql.FunctionCall{
+ Function: pgsql.FunctionExtract,
+ Parameters: []pgsql.Expression{pgsql.ProjectionFrom{
+ Projection: []pgsql.SelectItem{
+ pgsql.EpochIdentifier,
+ },
+ From: []pgsql.FromClause{{
+ Source: translatedAtom,
+ }},
+ }},
+ CastType: pgsql.Numeric,
+ })
+
+ case cypher.ITTCEpochMilliseconds:
+ s.treeTranslator.Push(pgsql.NewBinaryExpression(
+ pgsql.FunctionCall{
+ Function: pgsql.FunctionExtract,
+ Parameters: []pgsql.Expression{pgsql.ProjectionFrom{
+ Projection: []pgsql.SelectItem{
+ pgsql.EpochIdentifier,
+ },
+ From: []pgsql.FromClause{{
+ Source: translatedAtom,
+ }},
+ }},
+ CastType: pgsql.Numeric,
+ },
+ pgsql.OperatorMultiply,
+ pgsql.NewLiteral(1000, pgsql.Int4),
+ ))
+
+ default:
+ s.SetErrorf("unsupported date time instant type component %s from function call %s", componentName, typedTranslatedAtom.Function)
+ }
+
+ default:
+ s.SetErrorf("unsupported instant type component %s from function call %s", componentName, typedTranslatedAtom.Function)
+ }
+ }
+ }
+ }
+}
+
func (s *Translator) translateSetItem(setItem *cypher.SetItem) error {
if operator, err := translateCypherAssignmentOperator(setItem.Operator); err != nil {
return err
diff --git a/packages/go/cypher/models/pgsql/translate/translator.go b/packages/go/cypher/models/pgsql/translate/translator.go
index 68b057f5f6..02fa92daca 100644
--- a/packages/go/cypher/models/pgsql/translate/translator.go
+++ b/packages/go/cypher/models/pgsql/translate/translator.go
@@ -128,7 +128,7 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) {
case *cypher.RegularQuery, *cypher.SingleQuery, *cypher.PatternElement, *cypher.Return,
*cypher.Comparison, *cypher.Skip, *cypher.Limit, cypher.Operator, *cypher.ArithmeticExpression,
*cypher.NodePattern, *cypher.RelationshipPattern, *cypher.Remove, *cypher.Set,
- *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression:
+ *cypher.ReadingClause, *cypher.UnaryAddOrSubtractExpression, *cypher.PropertyLookup:
// No operation for these syntax nodes
case *cypher.Negation:
@@ -259,27 +259,6 @@ func (s *Translator) Enter(expression cypher.SyntaxNode) {
case *cypher.FunctionInvocation:
s.pushState(StateTranslatingNestedExpression)
- case *cypher.PropertyLookup:
- if variable, isVariable := typedExpression.Atom.(*cypher.Variable); !isVariable {
- s.SetErrorf("expected variable for property lookup reference but found type: %T", typedExpression.Atom)
- } else if resolved, isResolved := s.query.Scope.LookupString(variable.Symbol); !isResolved {
- s.SetErrorf("unable to resolve identifier: %s", variable.Symbol)
- } else {
- switch currentState := s.currentState(); currentState {
- case StateTranslatingNestedExpression:
- // TODO: Cypher does not support nested property references so the Symbols slice should be a string
- if fieldIdentifierLiteral, err := pgsql.AsLiteral(typedExpression.Symbols[0]); err != nil {
- s.SetError(err)
- } else {
- s.treeTranslator.Push(pgsql.CompoundIdentifier{resolved.Identifier, pgsql.ColumnProperties})
- s.treeTranslator.Push(fieldIdentifierLiteral)
- }
-
- default:
- s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression)
- }
- }
-
case *cypher.Order:
s.pushState(StateTranslatingOrderBy)
@@ -610,6 +589,31 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) {
})
}
+ case cypher.ListSizeFunction:
+ if typedExpression.NumArguments() > 1 {
+ s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name))
+ } else if argument, err := s.treeTranslator.Pop(); err != nil {
+ s.SetError(err)
+ } else {
+ var functionCall pgsql.FunctionCall
+
+ if _, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup {
+ functionCall = pgsql.FunctionCall{
+ Function: pgsql.FunctionJSONBArrayLength,
+ Parameters: []pgsql.Expression{argument},
+ CastType: pgsql.Int,
+ }
+ } else {
+ functionCall = pgsql.FunctionCall{
+ Function: pgsql.FunctionArrayLength,
+ Parameters: []pgsql.Expression{argument, pgsql.NewLiteral(1, pgsql.Int)},
+ CastType: pgsql.Int,
+ }
+ }
+
+ s.treeTranslator.Push(functionCall)
+ }
+
case cypher.ToUpperFunction:
if typedExpression.NumArguments() > 1 {
s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name))
@@ -628,6 +632,24 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) {
})
}
+ case cypher.ToStringFunction:
+ if typedExpression.NumArguments() > 1 {
+ s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name))
+ } else if argument, err := s.treeTranslator.Pop(); err != nil {
+ s.SetError(err)
+ } else {
+ s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Text))
+ }
+
+ case cypher.ToIntegerFunction:
+ if typedExpression.NumArguments() > 1 {
+ s.SetError(fmt.Errorf("expected only one argument for cypher function: %s", typedExpression.Name))
+ } else if argument, err := s.treeTranslator.Pop(); err != nil {
+ s.SetError(err)
+ } else {
+ s.treeTranslator.Push(pgsql.NewTypeCast(argument, pgsql.Int8))
+ }
+
default:
s.SetErrorf("unknown cypher function: %s", typedExpression.Name)
}
@@ -715,24 +737,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) {
}
case *cypher.PropertyLookup:
- switch currentState := s.currentState(); currentState {
- case StateTranslatingNestedExpression:
- if err := s.treeTranslator.PopPushOperator(s.query.Scope, pgsql.OperatorPropertyLookup); err != nil {
- s.SetError(err)
- }
-
- case StateTranslatingProjection:
- if nextExpression, err := s.treeTranslator.Pop(); err != nil {
- s.SetError(err)
- } else if selectItem, isProjection := nextExpression.(pgsql.SelectItem); !isProjection {
- s.SetErrorf("invalid type for select item: %T", nextExpression)
- } else {
- s.projections.CurrentProjection().SelectItem = selectItem
- }
-
- default:
- s.SetErrorf("invalid state \"%s\" for cypher AST node %T", s.currentState(), expression)
- }
+ s.translatePropertyLookup(typedExpression)
case *cypher.PartialComparison:
switch currentState := s.currentState(); currentState {
diff --git a/packages/go/cypher/models/walk/walk_cypher.go b/packages/go/cypher/models/walk/walk_cypher.go
index ee7d0a581a..f7f5a35c6d 100644
--- a/packages/go/cypher/models/walk/walk_cypher.go
+++ b/packages/go/cypher/models/walk/walk_cypher.go
@@ -36,12 +36,18 @@ func cypherSyntaxNodeSliceTypeConvert[F any, FS []F](fs FS) ([]cypher.SyntaxNode
func newCypherWalkCursor(node cypher.SyntaxNode) (*Cursor[cypher.SyntaxNode], error) {
switch typedNode := node.(type) {
// Types with no AST branches
- case *cypher.RangeQuantifier, *cypher.PropertyLookup, cypher.Operator, *cypher.KindMatcher,
+ case *cypher.RangeQuantifier, cypher.Operator, *cypher.KindMatcher,
*cypher.Limit, *cypher.Skip, graph.Kinds, *cypher.Parameter:
return &Cursor[cypher.SyntaxNode]{
Node: node,
}, nil
+ case *cypher.PropertyLookup:
+ return &Cursor[cypher.SyntaxNode]{
+ Node: node,
+ Branches: []cypher.SyntaxNode{typedNode.Atom},
+ }, nil
+
case *cypher.MapItem:
return &Cursor[cypher.SyntaxNode]{
Node: node,
diff --git a/packages/go/cypher/models/walk/walk_pgsql.go b/packages/go/cypher/models/walk/walk_pgsql.go
index 65df9c7cce..29924ae140 100644
--- a/packages/go/cypher/models/walk/walk_pgsql.go
+++ b/packages/go/cypher/models/walk/walk_pgsql.go
@@ -321,6 +321,16 @@ func newSQLWalkCursor(node pgsql.SyntaxNode) (*Cursor[pgsql.SyntaxNode], error)
Branches: []pgsql.SyntaxNode{typedNode.Subquery},
}, nil
+ case pgsql.ProjectionFrom:
+ if branches, err := pgsqlSyntaxNodeSliceTypeConvert(typedNode.From); err != nil {
+ return nil, err
+ } else {
+ return &Cursor[pgsql.SyntaxNode]{
+ Node: node,
+ Branches: append([]pgsql.SyntaxNode{typedNode.Projection}, branches...),
+ }, nil
+ }
+
default:
return nil, fmt.Errorf("unable to negotiate sql type %T into a translation cursor", node)
}
diff --git a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql
index 3ba6038445..3876967d2c 100644
--- a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql
+++ b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql
@@ -190,7 +190,6 @@ execute procedure delete_node_edges();
alter table edge
alter column properties set storage main;
-
-- Index on the graph ID of each edge.
create index if not exists edge_graph_id_index on edge using btree (graph_id);
diff --git a/packages/go/dawgs/drivers/pg/types.go b/packages/go/dawgs/drivers/pg/types.go
index a0305b84af..3ce3ae9f01 100644
--- a/packages/go/dawgs/drivers/pg/types.go
+++ b/packages/go/dawgs/drivers/pg/types.go
@@ -60,11 +60,76 @@ func castMapValueAsSliceOf[T any](compositeMap map[string]any, key string) ([]T,
func castAndAssignMapValue[T any](compositeMap map[string]any, key string, dst *T) error {
if src, hasKey := compositeMap[key]; !hasKey {
return fmt.Errorf("composite map does not contain expected key %s", key)
- } else if typed, typeOK := src.(T); !typeOK {
- var empty T
- return fmt.Errorf("expected type %T but received %T", empty, src)
} else {
- *dst = typed
+ switch typedSrc := src.(type) {
+ case int8:
+ switch typedDst := any(dst).(type) {
+ case *int8:
+ *typedDst = typedSrc
+ case *int16:
+ *typedDst = int16(typedSrc)
+ case *int32:
+ *typedDst = int32(typedSrc)
+ case *int64:
+ *typedDst = int64(typedSrc)
+ case *int:
+ *typedDst = int(typedSrc)
+ default:
+ return fmt.Errorf("unable to cast and assign value type: %T", src)
+ }
+
+ case int16:
+ switch typedDst := any(dst).(type) {
+ case *int16:
+ *typedDst = typedSrc
+ case *int32:
+ *typedDst = int32(typedSrc)
+ case *int64:
+ *typedDst = int64(typedSrc)
+ case *int:
+ *typedDst = int(typedSrc)
+ default:
+ return fmt.Errorf("unable to cast and assign value type: %T", src)
+ }
+
+ case int32:
+ switch typedDst := any(dst).(type) {
+ case *int32:
+ *typedDst = typedSrc
+ case *int64:
+ *typedDst = int64(typedSrc)
+ case *int:
+ *typedDst = int(typedSrc)
+ default:
+ return fmt.Errorf("unable to cast and assign value type: %T", src)
+ }
+
+ case int64:
+ switch typedDst := any(dst).(type) {
+ case *int64:
+ *typedDst = typedSrc
+ case *int:
+ *typedDst = int(typedSrc)
+ default:
+ return fmt.Errorf("unable to cast and assign value type: %T", src)
+ }
+
+ case int:
+ switch typedDst := any(dst).(type) {
+ case *int64:
+ *typedDst = int64(typedSrc)
+ case *int:
+ *typedDst = typedSrc
+ default:
+ return fmt.Errorf("unable to cast and assign value type: %T", src)
+ }
+
+ case T:
+ *dst = typedSrc
+
+ default:
+ return fmt.Errorf("unable to cast and assign value type: %T", src)
+ }
}
return nil