diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index d6510be063..37fcd06fa5 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -43,7 +43,7 @@ jobs: - name: Install Go uses: actions/setup-go@v4 with: - go-version: '^1.20.0' + go-version: '^1.21.0' - name: Install Python uses: actions/setup-python@v4 diff --git a/DEVREADME.md b/DEVREADME.md index 3bef6006bb..93ee0d1ea8 100644 --- a/DEVREADME.md +++ b/DEVREADME.md @@ -9,7 +9,7 @@ More detailed information regarding [contributing](https://github.com/SpecterOps - [Just](https://github.com/casey/just) - [Python 3.10](https://www.python.org/downloads/) -- [Go 1.20](https://go.dev/dl/) +- [Go 1.21](https://go.dev/dl/) - [Node 18](https://nodejs.dev/en/download/) - [Yarn 3.6](https://yarnpkg.com/getting-started/install) - [Docker Desktop](https://www.docker.com/products/docker-desktop/) (or Docker/Docker Compose compatible runtime) diff --git a/cmd/api/src/analysis/ad/ad_integration_test.go b/cmd/api/src/analysis/ad/ad_integration_test.go index 65fcbaab62..a27400ab26 100644 --- a/cmd/api/src/analysis/ad/ad_integration_test.go +++ b/cmd/api/src/analysis/ad/ad_integration_test.go @@ -21,6 +21,8 @@ package ad_test import ( "context" + schema "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/src/test" "testing" "github.com/specterops/bloodhound/analysis" @@ -33,34 +35,36 @@ import ( ) func TestFetchEnforcedGPOs(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { // Check the first user var ( enforcedGPOs, err = adAnalysis.FetchEnforcedGPOs(tx, 
harness.GPOEnforcement.UserC, 0, 0) ) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, enforcedGPOs.Len()) // Check the second user enforcedGPOs, err = adAnalysis.FetchEnforcedGPOs(tx, harness.GPOEnforcement.UserB, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, enforcedGPOs.Len()) }) } func TestFetchGPOAffectedContainerPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { containers, err := adAnalysis.FetchGPOAffectedContainerPaths(tx, harness.GPOEnforcement.GPOEnforced) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := containers.AllNodes().IDs() require.Equal(t, 6, len(nodes)) require.Contains(t, nodes, harness.GPOEnforcement.GPOEnforced.ID) @@ -71,7 +75,7 @@ func TestFetchGPOAffectedContainerPaths(t *testing.T) { require.Contains(t, nodes, harness.GPOEnforcement.OrganizationalUnitD.ID) containers, err = adAnalysis.FetchGPOAffectedContainerPaths(tx, harness.GPOEnforcement.GPOUnenforced) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = containers.AllNodes().IDs() require.Equal(t, 5, len(nodes)) require.Contains(t, nodes, harness.GPOEnforcement.GPOUnenforced.ID) @@ -83,20 +87,21 @@ func TestFetchGPOAffectedContainerPaths(t *testing.T) { } func TestCreateGPOAffectedIntermediariesListDelegateAffectedContainers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness 
*integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { containers, err := adAnalysis.CreateGPOAffectedIntermediariesListDelegate(adAnalysis.SelectGPOContainerCandidateFilter)(tx, harness.GPOEnforcement.GPOEnforced, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 5, containers.Len()) require.Equal(t, 4, containers.ContainingNodeKinds(ad.OU).Len()) require.Equal(t, 1, containers.ContainingNodeKinds(ad.Domain).Len()) containers, err = adAnalysis.CreateGPOAffectedIntermediariesListDelegate(adAnalysis.SelectGPOContainerCandidateFilter)(tx, harness.GPOEnforcement.GPOUnenforced, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 4, containers.Len()) require.False(t, containers.Contains(harness.GPOEnforcement.OrganizationalUnitC)) require.Equal(t, 3, containers.ContainingNodeKinds(ad.OU).Len()) @@ -105,13 +110,14 @@ func TestCreateGPOAffectedIntermediariesListDelegateAffectedContainers(t *testin } func TestCreateGPOAffectedIntermediariesPathDelegateAffectedUsers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { users, err := adAnalysis.CreateGPOAffectedIntermediariesPathDelegate(ad.User)(tx, harness.GPOEnforcement.GPOEnforced) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := users.AllNodes().IDs() require.Equal(t, 10, len(nodes)) require.Contains(t, nodes, harness.GPOEnforcement.GPOEnforced.ID) @@ -122,7 +128,7 @@ func TestCreateGPOAffectedIntermediariesPathDelegateAffectedUsers(t *testing.T) require.Contains(t, nodes, 
harness.GPOEnforcement.OrganizationalUnitC.ID) users, err = adAnalysis.CreateGPOAffectedIntermediariesPathDelegate(ad.User)(tx, harness.GPOEnforcement.GPOUnenforced) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = users.AllNodes().IDs() require.Equal(t, 8, len(nodes)) require.Contains(t, nodes, harness.GPOEnforcement.GPOUnenforced.ID) @@ -135,27 +141,29 @@ func TestCreateGPOAffectedIntermediariesPathDelegateAffectedUsers(t *testing.T) } func TestCreateGPOAffectedResultsListDelegateAffectedUsers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { users, err := adAnalysis.CreateGPOAffectedIntermediariesListDelegate(adAnalysis.SelectUsersCandidateFilter)(tx, harness.GPOEnforcement.GPOEnforced, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 4, users.Len()) users, err = adAnalysis.CreateGPOAffectedIntermediariesListDelegate(adAnalysis.SelectUsersCandidateFilter)(tx, harness.GPOEnforcement.GPOUnenforced, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 3, users.Len()) require.Equal(t, 3, users.ContainingNodeKinds(ad.User).Len()) }) } func TestCreateGPOAffectedIntermediariesListDelegateTierZero(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx 
graph.Transaction) { harness.GPOEnforcement.UserC.Properties.Set(common.SystemTags.String(), ad.AdminTierZero) harness.GPOEnforcement.UserD.Properties.Set(common.SystemTags.String(), ad.AdminTierZero) @@ -164,24 +172,25 @@ func TestCreateGPOAffectedIntermediariesListDelegateTierZero(t *testing.T) { users, err := adAnalysis.CreateGPOAffectedIntermediariesListDelegate(adAnalysis.SelectGPOTierZeroCandidateFilter)(tx, harness.GPOEnforcement.GPOEnforced, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, users.Len()) users, err = adAnalysis.CreateGPOAffectedIntermediariesListDelegate(adAnalysis.SelectGPOTierZeroCandidateFilter)(tx, harness.GPOEnforcement.GPOUnenforced, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, users.Len()) }) } func TestFetchComputerSessionPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Session.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { sessions, err := adAnalysis.FetchComputerSessionPaths(tx, harness.Session.ComputerA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := sessions.AllNodes().IDs() require.Equal(t, 2, len(nodes)) require.Contains(t, nodes, harness.Session.ComputerA.ID) @@ -190,63 +199,67 @@ func TestFetchComputerSessionPaths(t *testing.T) { } func TestFetchComputerSessions(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Session.Setup(testContext) + return nil 
}, func(harness integration.HarnessDetails, tx graph.Transaction) { sessions, err := adAnalysis.FetchComputerSessions(tx, harness.Session.ComputerA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, sessions.Len()) }) } func TestFetchGroupSessionPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Session.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { computers, err := adAnalysis.FetchGroupSessionPaths(tx, harness.Session.GroupA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := computers.AllNodes().IDs() require.Equal(t, 4, len(nodes)) nestedComputers, err := adAnalysis.FetchGroupSessionPaths(tx, harness.Session.GroupC) - require.Nil(t, err) + test.RequireNilErr(t, err) nestedNodes := nestedComputers.AllNodes().IDs() require.Equal(t, 5, len(nestedNodes)) }) } func TestFetchGroupSessions(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Session.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { computers, err := adAnalysis.FetchGroupSessions(tx, harness.Session.GroupA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, computers.Len()) require.Equal(t, 2, computers.ContainingNodeKinds(ad.Computer).Len()) nestedComputers, err := adAnalysis.FetchGroupSessions(tx, harness.Session.GroupC, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) 
require.Equal(t, 2, nestedComputers.Len()) require.Equal(t, 2, nestedComputers.ContainingNodeKinds(ad.Computer).Len()) }) } func TestFetchUserSessionPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Session.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { sessions, err := adAnalysis.FetchUserSessionPaths(tx, harness.Session.User) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := sessions.AllNodes().IDs() require.Equal(t, 3, len(nodes)) require.Contains(t, nodes, harness.Session.User.ID) @@ -256,26 +269,28 @@ func TestFetchUserSessionPaths(t *testing.T) { } func TestFetchUserSessions(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Session.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { computers, err := adAnalysis.FetchUserSessions(tx, harness.Session.User, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, computers.Len()) require.Equal(t, 2, computers.ContainingNodeKinds(ad.Computer).Len()) }) } func TestCreateOutboundLocalGroupPathDelegateUser(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { 
harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.CreateOutboundLocalGroupPathDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.UserB) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.UserB.ID) require.Contains(t, nodes, harness.LocalGroupSQL.GroupA.ID) @@ -285,26 +300,28 @@ func TestCreateOutboundLocalGroupPathDelegateUser(t *testing.T) { } func TestCreateOutboundLocalGroupListDelegateUser(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { computers, err := adAnalysis.CreateOutboundLocalGroupListDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.UserB, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, computers.Len()) require.Equal(t, harness.LocalGroupSQL.ComputerA.ID, computers.Slice()[0].ID) }) } func TestCreateOutboundLocalGroupPathDelegateGroup(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.CreateOutboundLocalGroupPathDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.GroupA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() 
require.Contains(t, nodes, harness.LocalGroupSQL.GroupA.ID) require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) @@ -313,26 +330,28 @@ func TestCreateOutboundLocalGroupPathDelegateGroup(t *testing.T) { } func TestCreateOutboundLocalGroupListDelegateGroup(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { computers, err := adAnalysis.CreateOutboundLocalGroupListDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.GroupA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, computers.Len()) require.Equal(t, harness.LocalGroupSQL.ComputerA.ID, computers.Slice()[0].ID) }) } func TestCreateOutboundLocalGroupPathDelegateComputer(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.CreateOutboundLocalGroupPathDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) require.Contains(t, nodes, harness.LocalGroupSQL.ComputerB.ID) @@ -343,25 +362,27 @@ func TestCreateOutboundLocalGroupPathDelegateComputer(t *testing.T) { } func TestCreateOutboundLocalGroupListDelegateComputer(t *testing.T) { - testContext := 
integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { computers, err := adAnalysis.CreateOutboundLocalGroupListDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, computers.Len()) }) } func TestCreateInboundLocalGroupPathDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.CreateInboundLocalGroupPathDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.UserB.ID) require.Contains(t, nodes, harness.LocalGroupSQL.UserA.ID) @@ -370,7 +391,7 @@ func TestCreateInboundLocalGroupPathDelegate(t *testing.T) { require.Equal(t, 4, len(nodes)) path, err = adAnalysis.CreateInboundLocalGroupPathDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerC) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) require.Contains(t, nodes, harness.LocalGroupSQL.GroupB.ID) @@ -380,45 +401,47 @@ func TestCreateInboundLocalGroupPathDelegate(t *testing.T) { } func TestCreateInboundLocalGroupListDelegate(t *testing.T) { - 
testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { admins, err := adAnalysis.CreateInboundLocalGroupListDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, admins.Len()) require.Equal(t, 2, admins.ContainingNodeKinds(ad.User).Len()) admins, err = adAnalysis.CreateInboundLocalGroupListDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerC, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, admins.Len()) require.Equal(t, harness.LocalGroupSQL.ComputerA.ID, admins.Slice()[0].ID) admins, err = adAnalysis.CreateInboundLocalGroupListDelegate(ad.AdminTo)(tx, harness.LocalGroupSQL.ComputerB, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, admins.Len()) require.Equal(t, harness.LocalGroupSQL.ComputerA.ID, admins.Slice()[0].ID) }) } func TestCreateSQLAdminPathDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.CreateSQLAdminPathDelegate(graph.DirectionInbound)(tx, harness.LocalGroupSQL.ComputerA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.UserC.ID) 
require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) require.Equal(t, 2, len(nodes)) path, err = adAnalysis.CreateSQLAdminPathDelegate(graph.DirectionOutbound)(tx, harness.LocalGroupSQL.UserC) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.UserC.ID) require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) @@ -427,36 +450,38 @@ func TestCreateSQLAdminPathDelegate(t *testing.T) { } func TestCreateSQLAdminListDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { admins, err := adAnalysis.CreateSQLAdminListDelegate(graph.DirectionInbound)(tx, harness.LocalGroupSQL.ComputerA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, admins.Len()) computers, err := adAnalysis.CreateSQLAdminListDelegate(graph.DirectionOutbound)(tx, harness.LocalGroupSQL.UserC, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, computers.Len()) }) } func TestCreateConstrainedDelegationPathDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.CreateConstrainedDelegationPathDelegate(graph.DirectionInbound)(tx, harness.LocalGroupSQL.ComputerA) - 
require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.UserD.ID) require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) require.Equal(t, 2, len(nodes)) path, err = adAnalysis.CreateConstrainedDelegationPathDelegate(graph.DirectionOutbound)(tx, harness.LocalGroupSQL.UserD) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = path.AllNodes().IDs() require.Contains(t, nodes, harness.LocalGroupSQL.UserD.ID) require.Contains(t, nodes, harness.LocalGroupSQL.ComputerA.ID) @@ -465,29 +490,31 @@ func TestCreateConstrainedDelegationPathDelegate(t *testing.T) { } func TestCreateConstrainedDelegationListDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.LocalGroupSQL.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { admins, err := adAnalysis.CreateConstrainedDelegationListDelegate(graph.DirectionInbound)(tx, harness.LocalGroupSQL.ComputerA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, admins.Len()) computers, err := adAnalysis.CreateConstrainedDelegationListDelegate(graph.DirectionOutbound)(tx, harness.LocalGroupSQL.UserD, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, computers.Len()) }) } func TestFetchOutboundADEntityControlPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.OutboundControl.Setup(testContext) - }, 
func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { path, err := adAnalysis.FetchOutboundADEntityControlPaths(context.Background(), db, harness.OutboundControl.Controller) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Equal(t, 7, len(nodes)) require.Contains(t, nodes, harness.OutboundControl.Controller.ID) @@ -497,19 +524,18 @@ func TestFetchOutboundADEntityControlPaths(t *testing.T) { require.Contains(t, nodes, harness.OutboundControl.GroupC.ID) require.Contains(t, nodes, harness.OutboundControl.ComputerA.ID) require.Contains(t, nodes, harness.OutboundControl.ComputerC.ID) - - return nil }) } func TestFetchOutboundADEntityControl(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.OutboundControl.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { control, err := adAnalysis.FetchOutboundADEntityControl(context.Background(), db, harness.OutboundControl.Controller, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 4, control.Len()) ids := control.IDs() @@ -519,21 +545,20 @@ func TestFetchOutboundADEntityControl(t *testing.T) { require.Contains(t, ids, harness.OutboundControl.ComputerC.ID) control, err = adAnalysis.FetchOutboundADEntityControl(context.Background(), db, harness.OutboundControl.ControllerB, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, control.Len()) - - return nil }) } func TestFetchInboundADEntityControllerPaths(t *testing.T) { t.Run("User", func(t *testing.T) { - testContext 
:= integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.InboundControl.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { path, err := adAnalysis.FetchInboundADEntityControllerPaths(context.Background(), db, harness.InboundControl.ControlledUser) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Equal(t, 5, len(nodes)) @@ -542,18 +567,17 @@ func TestFetchInboundADEntityControllerPaths(t *testing.T) { require.Contains(t, nodes, harness.InboundControl.UserA.ID) require.Contains(t, nodes, harness.InboundControl.GroupB.ID) require.Contains(t, nodes, harness.InboundControl.UserD.ID) - - return nil }) }) t.Run("Group", func(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.InboundControl.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { path, err := adAnalysis.FetchInboundADEntityControllerPaths(context.Background(), db, harness.InboundControl.ControlledGroup) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Equal(t, 7, len(nodes)) @@ -564,20 +588,19 @@ func TestFetchInboundADEntityControllerPaths(t *testing.T) { require.Contains(t, nodes, harness.InboundControl.UserG.ID) require.Contains(t, nodes, harness.InboundControl.GroupD.ID) require.Contains(t, 
nodes, harness.InboundControl.UserH.ID) - - return nil }) }) } func TestFetchInboundADEntityControllers(t *testing.T) { t.Run("User", func(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.InboundControl.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { control, err := adAnalysis.FetchInboundADEntityControllers(context.Background(), db, harness.InboundControl.ControlledUser, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 4, control.Len()) ids := control.IDs() @@ -587,20 +610,19 @@ func TestFetchInboundADEntityControllers(t *testing.T) { require.Contains(t, ids, harness.InboundControl.GroupA.ID) control, err = adAnalysis.FetchInboundADEntityControllers(context.Background(), db, harness.InboundControl.ControlledUser, 0, 1) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, control.Len()) - - return nil }) }) t.Run("Group", func(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.InboundControl.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { control, err := adAnalysis.FetchInboundADEntityControllers(context.Background(), db, harness.InboundControl.ControlledGroup, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 6, 
control.Len()) ids := control.IDs() @@ -612,22 +634,21 @@ func TestFetchInboundADEntityControllers(t *testing.T) { require.Contains(t, ids, harness.InboundControl.UserH.ID) control, err = adAnalysis.FetchInboundADEntityControllers(context.Background(), db, harness.InboundControl.ControlledGroup, 0, 1) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, control.Len()) - - return nil }) }) } func TestCreateOUContainedPathDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.OUHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.CreateOUContainedPathDelegate(ad.User)(tx, harness.OUHarness.OUA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 4, len(nodes)) require.Contains(t, nodes, harness.OUHarness.OUA.ID) @@ -636,7 +657,7 @@ func TestCreateOUContainedPathDelegate(t *testing.T) { require.Contains(t, nodes, harness.OUHarness.UserB.ID) paths, err = adAnalysis.CreateOUContainedPathDelegate(ad.User)(tx, harness.OUHarness.OUB) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = paths.AllNodes().IDs() require.Equal(t, 4, len(nodes)) require.Contains(t, nodes, harness.OUHarness.OUB.ID) @@ -647,29 +668,31 @@ func TestCreateOUContainedPathDelegate(t *testing.T) { } func TestCreateOUContainedListDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { 
harness.OUHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { contained, err := adAnalysis.CreateOUContainedListDelegate(ad.User)(tx, harness.OUHarness.OUA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, contained.Len()) contained, err = adAnalysis.CreateOUContainedListDelegate(ad.User)(tx, harness.OUHarness.OUB, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, contained.Len()) }) } func TestFetchGroupMemberPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.MembershipHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { path, err := adAnalysis.FetchGroupMemberPaths(tx, harness.MembershipHarness.GroupB) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := path.AllNodes().IDs() require.Equal(t, 3, len(nodes)) require.Contains(t, nodes, harness.MembershipHarness.GroupB.ID) @@ -679,13 +702,14 @@ func TestFetchGroupMemberPaths(t *testing.T) { } func TestFetchGroupMembers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.MembershipHarness.Setup(testContext) - }, func(harness integration.HarnessDetails, tx graph.Transaction) { - members, err := adAnalysis.FetchGroupMembers(tx, harness.MembershipHarness.GroupC, 0, 0) + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { + members, err := 
adAnalysis.FetchGroupMembers(context.Background(), db, harness.MembershipHarness.GroupC, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 5, members.Len()) require.Equal(t, 2, members.ContainingNodeKinds(ad.Computer).Len()) require.Equal(t, 2, members.ContainingNodeKinds(ad.Group).Len()) @@ -694,13 +718,14 @@ func TestFetchGroupMembers(t *testing.T) { } func TestFetchEntityGroupMembershipPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.MembershipHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.FetchEntityGroupMembershipPaths(tx, harness.MembershipHarness.UserA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 4, len(nodes)) require.Contains(t, nodes, harness.MembershipHarness.UserA.ID) @@ -710,46 +735,49 @@ func TestFetchEntityGroupMembershipPaths(t *testing.T) { } func TestFetchEntityGroupMembership(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.MembershipHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { membership, err := adAnalysis.FetchEntityGroupMembership(tx, harness.MembershipHarness.UserA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 3, membership.Len()) }) } func TestCreateForeignEntityMembershipListDelegate(t *testing.T) { - testContext := 
integration.NewGraphTestContext(t) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ForeignHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { members, err := adAnalysis.CreateForeignEntityMembershipListDelegate(ad.Group)(tx, harness.ForeignHarness.LocalDomain, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, members.Len()) require.Equal(t, 1, members.ContainingNodeKinds(ad.Group).Len()) members, err = adAnalysis.CreateForeignEntityMembershipListDelegate(ad.User)(tx, harness.ForeignHarness.LocalDomain, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, members.Len()) require.Equal(t, 2, members.ContainingNodeKinds(ad.User).Len()) }) } func TestFetchCollectedDomains(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.TrustDCSync.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { domains, err := adAnalysis.FetchCollectedDomains(tx) - require.Nil(t, err) + test.RequireNilErr(t, err) for _, domain := range domains { collected, err := domain.Properties.Get(common.Collected.String()).Bool() - require.Nil(t, err) + test.RequireNilErr(t, err) require.True(t, collected) } require.Equal(t, harness.NumCollectedActiveDirectoryDomains, domains.Len()) @@ -758,14 +786,15 @@ func TestFetchCollectedDomains(t *testing.T) { } func TestCreateDomainTrustPathDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := 
integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.TrustDCSync.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.CreateDomainTrustPathDelegate(graph.DirectionOutbound)(tx, harness.TrustDCSync.DomainA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 4, len(nodes)) require.Contains(t, nodes, harness.TrustDCSync.DomainA.ID) @@ -775,7 +804,7 @@ func TestCreateDomainTrustPathDelegate(t *testing.T) { paths, err = adAnalysis.CreateDomainTrustPathDelegate(graph.DirectionInbound)(tx, harness.TrustDCSync.DomainA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes = paths.AllNodes().IDs() require.Equal(t, 3, len(nodes)) require.Contains(t, nodes, harness.TrustDCSync.DomainA.ID) @@ -785,14 +814,15 @@ func TestCreateDomainTrustPathDelegate(t *testing.T) { } func TestCreateDomainTrustListDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.TrustDCSync.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { domains, err := adAnalysis.CreateDomainTrustListDelegate(graph.DirectionOutbound)(tx, harness.TrustDCSync.DomainA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 3, domains.Len()) ids := domains.IDs() require.Contains(t, ids, harness.TrustDCSync.DomainB.ID) @@ -801,7 +831,7 @@ func TestCreateDomainTrustListDelegate(t *testing.T) { domains, err = 
adAnalysis.CreateDomainTrustListDelegate(graph.DirectionInbound)(tx, harness.TrustDCSync.DomainA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, domains.Len()) ids = domains.IDs() require.Contains(t, ids, harness.TrustDCSync.DomainB.ID) @@ -810,15 +840,16 @@ func TestCreateDomainTrustListDelegate(t *testing.T) { } func TestGetDCSyncers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) // XXX: Why does this need a WriteTransaction to run? - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.TrustDCSync.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { dcSyncers, err := analysis.GetDCSyncers(tx, harness.TrustDCSync.DomainA, true) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, len(dcSyncers)) ids := make([]graph.ID, len(dcSyncers)) for _, node := range dcSyncers { @@ -832,7 +863,7 @@ func TestGetDCSyncers(t *testing.T) { dcSyncers, err = analysis.GetDCSyncers(tx, harness.TrustDCSync.DomainA, true) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, len(dcSyncers)) ids = make([]graph.ID, len(dcSyncers)) for _, node := range dcSyncers { @@ -844,14 +875,15 @@ func TestGetDCSyncers(t *testing.T) { } func TestFetchDCSyncers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.TrustDCSync.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { dcSyncers, err := adAnalysis.FetchDCSyncers(tx, 
harness.TrustDCSync.DomainA, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, dcSyncers.Len()) nodes := dcSyncers.IDs() @@ -861,14 +893,15 @@ func TestFetchDCSyncers(t *testing.T) { } func TestFetchDCSyncerPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.TrustDCSync.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.FetchDCSyncerPaths(tx, harness.TrustDCSync.DomainA) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 5, len(nodes)) require.Contains(t, nodes, harness.TrustDCSync.DomainA.ID) @@ -880,14 +913,15 @@ func TestFetchDCSyncerPaths(t *testing.T) { } func TestCreateForeignEntityMembershipPathDelegate(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.WriteTransactionTest(func(harness *integration.HarnessDetails) { + testContext.WriteTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ForeignHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.CreateForeignEntityMembershipPathDelegate(ad.Group)(tx, harness.ForeignHarness.LocalDomain) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 2, len(nodes)) require.Contains(t, nodes, harness.ForeignHarness.ForeignGroup.ID) @@ -895,7 +929,7 @@ func TestCreateForeignEntityMembershipPathDelegate(t *testing.T) { paths, err = adAnalysis.CreateForeignEntityMembershipPathDelegate(ad.User)(tx, harness.ForeignHarness.LocalDomain) - 
require.Nil(t, err) + test.RequireNilErr(t, err) nodes = paths.AllNodes().IDs() require.Equal(t, 4, len(nodes)) require.Contains(t, nodes, harness.ForeignHarness.ForeignGroup.ID) @@ -906,28 +940,30 @@ func TestCreateForeignEntityMembershipPathDelegate(t *testing.T) { } func TestFetchForeignAdmins(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ForeignHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { admins, err := adAnalysis.FetchForeignAdmins(tx, harness.ForeignHarness.LocalDomain, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, admins.Len()) require.Equal(t, 2, admins.ContainingNodeKinds(ad.User).Len()) }) } func TestFetchForeignAdminPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ForeignHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.FetchForeignAdminPaths(tx, harness.ForeignHarness.LocalDomain) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 5, len(nodes)) require.Contains(t, nodes, harness.ForeignHarness.LocalComputer.ID) @@ -939,14 +975,15 @@ func TestFetchForeignAdminPaths(t *testing.T) { } func TestFetchForeignGPOControllers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - 
testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ForeignHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { admins, err := adAnalysis.FetchForeignGPOControllers(tx, harness.ForeignHarness.LocalDomain, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, admins.Len()) require.Equal(t, 1, admins.ContainingNodeKinds(ad.User).Len()) require.Equal(t, 1, admins.ContainingNodeKinds(ad.Group).Len()) @@ -954,14 +991,15 @@ func TestFetchForeignGPOControllers(t *testing.T) { } func TestFetchForeignGPOControllerPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ForeignHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.FetchForeignGPOControllerPaths(tx, harness.ForeignHarness.LocalDomain) - require.Nil(t, err) + test.RequireNilErr(t, err) nodes := paths.AllNodes().IDs() require.Equal(t, 3, len(nodes)) require.Contains(t, nodes, harness.ForeignHarness.ForeignUserA.ID) @@ -971,45 +1009,48 @@ func TestFetchForeignGPOControllerPaths(t *testing.T) { } func TestFetchAllEnforcedGPOs(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) - }, func(harness integration.HarnessDetails, tx graph.Transaction) { - gpos, err := 
adAnalysis.FetchAllEnforcedGPOs(tx, graph.NewNodeSet(harness.GPOEnforcement.OrganizationalUnitD)) + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { + gpos, err := adAnalysis.FetchAllEnforcedGPOs(context.Background(), db, graph.NewNodeSet(harness.GPOEnforcement.OrganizationalUnitD)) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, gpos.Len()) - gpos, err = adAnalysis.FetchAllEnforcedGPOs(tx, graph.NewNodeSet(harness.GPOEnforcement.OrganizationalUnitC)) + gpos, err = adAnalysis.FetchAllEnforcedGPOs(context.Background(), db, graph.NewNodeSet(harness.GPOEnforcement.OrganizationalUnitC)) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 1, gpos.Len()) }) } func TestFetchEntityLinkedGPOList(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { gpos, err := adAnalysis.FetchEntityLinkedGPOList(tx, harness.GPOEnforcement.Domain, 0, 0) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, 2, gpos.Len()) }) } func TestFetchEntityLinkedGPOPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.GPOEnforcement.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := adAnalysis.FetchEntityLinkedGPOPaths(tx, harness.GPOEnforcement.Domain) - require.Nil(t, err) + test.RequireNilErr(t, err) 
nodes := paths.AllNodes().IDs() require.Equal(t, 3, len(nodes)) require.Contains(t, nodes, harness.GPOEnforcement.Domain.ID) @@ -1019,27 +1060,29 @@ func TestFetchEntityLinkedGPOPaths(t *testing.T) { } func TestFetchLocalGroupCompleteness(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Completeness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { completeness, err := adAnalysis.FetchLocalGroupCompleteness(tx, harness.Completeness.DomainSid) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, .5, completeness) }) } func TestFetchUserSessionCompleteness(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.Completeness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { completeness, err := adAnalysis.FetchUserSessionCompleteness(tx, harness.Completeness.DomainSid) - require.Nil(t, err) + test.RequireNilErr(t, err) require.Equal(t, .5, completeness) }) } diff --git a/cmd/api/src/analysis/ad/adcs_integration_test.go b/cmd/api/src/analysis/ad/adcs_integration_test.go index 8435423a89..20040316f7 100644 --- a/cmd/api/src/analysis/ad/adcs_integration_test.go +++ b/cmd/api/src/analysis/ad/adcs_integration_test.go @@ -22,6 +22,7 @@ package ad_test import ( "context" "github.com/specterops/bloodhound/analysis" + "github.com/specterops/bloodhound/graphschema" ad2 "github.com/specterops/bloodhound/analysis/ad" @@ -39,11 
+40,12 @@ import ( ) func TestADCSESC1(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ADCSESC1Harness.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - ESC1") groupExpansions, err := ad2.ExpandAllRDPLocalGroups(context.Background(), db) @@ -106,17 +108,16 @@ func TestADCSESC1(t *testing.T) { } return nil }) - return nil }) - } func TestGoldenCert(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ADCSGoldenCertHarness.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - Golden Cert") domains, err := ad2.FetchNodesByKind(context.Background(), db, ad.Domain) @@ -170,16 +171,17 @@ func TestGoldenCert(t *testing.T) { } return nil }) - return nil }) } func TestCanAbuseUPNCertMapping(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) + + testContext.DatabaseTestWithSetup(func(harness 
*integration.HarnessDetails) error { harness.WeakCertBindingAndUPNCertMappingHarness.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - CanAbuseUPNCertMapping") if enterpriseCertAuthorities, err := ad2.FetchNodesByKind(context.Background(), db, ad.EnterpriseCA); err != nil { @@ -188,6 +190,7 @@ func TestCanAbuseUPNCertMapping(t *testing.T) { t.Logf("failed post processing for %s: %v", ad.CanAbuseUPNCertMapping.String(), err) } + // TODO: We're throwing away the collected errors from the operation and should assert on them operation.Done() db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { @@ -214,15 +217,15 @@ func TestCanAbuseUPNCertMapping(t *testing.T) { } return nil }) - return nil }) } func TestCanAbuseWeakCertBinding(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.WeakCertBindingAndUPNCertMappingHarness.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - CanAbuseWeakCertBinding") if enterpriseCertAuthorities, err := ad2.FetchNodesByKind(context.Background(), db, ad.EnterpriseCA); err != nil { @@ -231,6 +234,7 @@ func TestCanAbuseWeakCertBinding(t *testing.T) { t.Logf("failed post processing for %s: %v", ad.CanAbuseWeakCertBinding.String(), err) } + // TODO: We're throwing away the collected errors from the operation and 
should assert on them operation.Done() db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { @@ -255,17 +259,18 @@ func TestCanAbuseWeakCertBinding(t *testing.T) { assert.False(t, results.Contains(harness.WeakCertBindingAndUPNCertMappingHarness.Domain2)) assert.False(t, results.Contains(harness.WeakCertBindingAndUPNCertMappingHarness.Domain3)) } + return nil }) - return nil }) } func TestIssuedSignedBy(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.IssuedSignedByHarness.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - IssuedSignedBy") if rootCertAuthorities, err := ad2.FetchNodesByKind(context.Background(), db, ad.RootCA); err != nil { @@ -322,20 +327,21 @@ func TestIssuedSignedBy(t *testing.T) { assert.False(t, results2.Contains(harness.IssuedSignedByHarness.EnterpriseCA3)) assert.False(t, results3.Contains(harness.IssuedSignedByHarness.EnterpriseCA3)) } + return nil }) - return nil }) } func TestTrustedForNTAuth(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) error { harness.TrustedForNTAuthHarness.Setup(testContext) + return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { + func(harness integration.HarnessDetails, db graph.Database) { // post `TrustedForNTAuth` edges operation := 
analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - TrustedForNTAuth") @@ -364,16 +370,15 @@ func TestTrustedForNTAuth(t *testing.T) { } return nil }) - - return nil }) } func TestEnrollOnBehalfOf(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.EnrollOnBehalfOfHarnessOne.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate) v1Templates := make([]*graph.Node, 0) v2Templates := make([]*graph.Node, 0) @@ -386,7 +391,9 @@ func TestEnrollOnBehalfOf(t *testing.T) { v2Templates = append(v2Templates, template) } } + require.Nil(t, err) + db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { results, err := ad2.EnrollOnBehalfOfVersionOne(tx, v1Templates, certTemplates) require.Nil(t, err) @@ -413,16 +420,16 @@ func TestEnrollOnBehalfOf(t *testing.T) { return nil }) - - return nil }) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.EnrollOnBehalfOfHarnessTwo.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { certTemplates, err := ad2.FetchNodesByKind(context.Background(), db, ad.CertTemplate) v1Templates := make([]*graph.Node, 0) v2Templates := make([]*graph.Node, 0) + for _, template := range certTemplates { if version, err := template.Properties.Get(ad.SchemaVersion.String()).Float64(); err != nil 
{ continue @@ -432,7 +439,9 @@ func TestEnrollOnBehalfOf(t *testing.T) { v2Templates = append(v2Templates, template) } } + require.Nil(t, err) + db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { results, err := ad2.EnrollOnBehalfOfVersionTwo(tx, v2Templates, certTemplates) require.Nil(t, err) @@ -446,16 +455,15 @@ func TestEnrollOnBehalfOf(t *testing.T) { return nil }) - - return nil }) } func TestADCSESC3(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, graphschema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ESC3Harness1.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - ESC3") groupExpansions, err := ad2.ExpandAllRDPLocalGroups(context.Background(), db) @@ -506,12 +514,12 @@ func TestADCSESC3(t *testing.T) { } return nil }) - return nil }) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ESC3Harness2.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { operation := analysis.NewPostRelationshipOperation(context.Background(), db, "ADCS Post Process Test - ESC3") groupExpansions, err := ad2.ExpandAllRDPLocalGroups(context.Background(), db) @@ -570,6 +578,5 @@ func TestADCSESC3(t *testing.T) { } return nil }) - return nil }) } diff --git a/cmd/api/src/analysis/ad/tierzero.go b/cmd/api/src/analysis/ad/tierzero.go deleted file mode 100644 index 
a3ca714827..0000000000 --- a/cmd/api/src/analysis/ad/tierzero.go +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2023 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package ad - -import ( - analysis "github.com/specterops/bloodhound/analysis/ad" - "github.com/specterops/bloodhound/dawgs/graph" - "github.com/specterops/bloodhound/dawgs/ops" - "github.com/specterops/bloodhound/dawgs/query" - "github.com/specterops/bloodhound/graphschema/ad" - "github.com/specterops/bloodhound/graphschema/common" -) - -func TierZeroWellKnownSIDSuffixes() []string { - return []string{ - analysis.EnterpriseDomainControllersGroupSIDSuffix, - analysis.AdministratorAccountSIDSuffix, - analysis.DomainAdminsGroupSIDSuffix, - analysis.DomainControllersGroupSIDSuffix, - analysis.SchemaAdminsGroupSIDSuffix, - analysis.EnterpriseAdminsGroupSIDSuffix, - analysis.KeyAdminsGroupSIDSuffix, - analysis.EnterpriseKeyAdminsGroupSIDSuffix, - analysis.BackupOperatorsGroupSIDSuffix, - analysis.AdministratorsGroupSIDSuffix, - } -} - -func FetchWellKnownTierZeroEntities(tx graph.Transaction, domainSID string) (graph.NodeSet, error) { - nodes := graph.NewNodeSet() - - for _, wellKnownSIDSuffix := range TierZeroWellKnownSIDSuffixes() { - if err := tx.Nodes().Filterf(func() graph.Criteria { - return query.And( - // Make sure we have the Group or User label. 
This should cover the case for URA as well as filter out all the other localgroups - query.KindIn(query.Node(), ad.Group, ad.User), - query.StringEndsWith(query.NodeProperty(common.ObjectID.String()), wellKnownSIDSuffix), - query.Equals(query.NodeProperty(ad.DomainSID.String()), domainSID), - ) - }).Fetch(func(cursor graph.Cursor[*graph.Node]) error { - for node := range cursor.Chan() { - nodes.Add(node) - } - - return cursor.Error() - }); err != nil { - return nil, err - } - } - - return nodes, nil -} - -func FetchAllGroupMembers(tx graph.Transaction, targets graph.NodeSet) (graph.NodeSet, error) { - allGroupMembers := graph.NewNodeSet() - - for _, target := range targets { - if target.Kinds.ContainsOneOf(ad.Group) { - if groupMembers, err := analysis.FetchGroupMembers(tx, target, 0, 0); err != nil { - return nil, err - } else { - allGroupMembers.AddSet(groupMembers) - } - } - } - - return allGroupMembers, nil -} - -func FetchDomainTierZeroAssets(tx graph.Transaction, domain *graph.Node) (graph.NodeSet, error) { - domainSID, _ := domain.Properties.GetOrDefault(ad.DomainSID.String(), "").String() - - return ops.FetchNodeSet(tx.Nodes().Filterf(func() graph.Criteria { - return query.And( - query.Kind(query.Node(), ad.Entity), - query.Equals(query.NodeProperty(ad.DomainSID.String()), domainSID), - query.StringContains(query.NodeProperty(common.SystemTags.String()), ad.AdminTierZero), - ) - })) -} diff --git a/cmd/api/src/analysis/analysis_integration_test.go b/cmd/api/src/analysis/analysis_integration_test.go index 35261c8225..b07da81ff3 100644 --- a/cmd/api/src/analysis/analysis_integration_test.go +++ b/cmd/api/src/analysis/analysis_integration_test.go @@ -21,6 +21,8 @@ package analysis_test import ( "context" + schema "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/src/test" "testing" analysis "github.com/specterops/bloodhound/analysis/ad" @@ -32,10 +34,11 @@ import ( ) func TestFetchRDPEnsureNoDescent(t *testing.T) { - 
testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.RDPB.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { groupExpansions, err := analysis.ExpandAllRDPLocalGroups(context.Background(), db) require.Nil(t, err) @@ -50,16 +53,15 @@ func TestFetchRDPEnsureNoDescent(t *testing.T) { return nil })) - - return nil }) } func TestFetchRDPEntityBitmapForComputer(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.RDP.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { groupExpansions, err := analysis.ExpandAllRDPLocalGroups(context.Background(), db) require.Nil(t, err) @@ -116,7 +118,7 @@ func TestFetchRDPEntityBitmapForComputer(t *testing.T) { // Create a RemoteInteractiveLogonPrivilege relationship from the RDP local group to the computer to test our most common case require.Nil(t, db.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - _, err := tx.CreateRelationship(harness.RDP.RDPLocalGroup, harness.RDP.Computer, ad.RemoteInteractiveLogonPrivilege, graph.NewProperties()) + _, err := tx.CreateRelationshipByIDs(harness.RDP.RDPLocalGroup.ID, harness.RDP.Computer.ID, ad.RemoteInteractiveLogonPrivilege, graph.NewProperties()) return err })) @@ -124,7 +126,7 @@ func 
TestFetchRDPEntityBitmapForComputer(t *testing.T) { groupExpansions, err = analysis.ExpandAllRDPLocalGroups(context.Background(), db) require.Nil(t, err) - return db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { + test.RequireNilErr(t, db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { rdpEnabledEntityIDBitmap, err := analysis.FetchRDPEntityBitmapForComputer(tx, harness.RDP.Computer.ID, groupExpansions, true) require.Nil(t, err) @@ -138,6 +140,6 @@ func TestFetchRDPEntityBitmapForComputer(t *testing.T) { require.True(t, rdpEnabledEntityIDBitmap.Contains(harness.RDP.DomainGroupA.ID.Uint32())) return nil - }) + })) }) } diff --git a/cmd/api/src/analysis/azure/azure_integration_test.go b/cmd/api/src/analysis/azure/azure_integration_test.go index bbee5cb7dc..66b64fc33d 100644 --- a/cmd/api/src/analysis/azure/azure_integration_test.go +++ b/cmd/api/src/analysis/azure/azure_integration_test.go @@ -20,6 +20,7 @@ package azure_test import ( "context" + schema "github.com/specterops/bloodhound/graphschema" "sort" "testing" @@ -45,9 +46,10 @@ func SortIDs(ids []graph.ID) []graph.ID { } func TestFetchEntityByObjectID(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZBaseHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { node, err := azureanalysis.FetchEntityByObjectID(tx, testContext.NodeObjectID(harness.AZBaseHarness.Application)) @@ -57,9 +59,10 @@ func TestFetchEntityByObjectID(t *testing.T) { } func TestEntityRoles(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := 
integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZBaseHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { roles, err := azureanalysis.FetchEntityRoles(tx, harness.AZBaseHarness.User, 0, 0) @@ -69,9 +72,10 @@ func TestEntityRoles(t *testing.T) { } func TestTraverseNodePaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZBaseHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { // Preform a full traversal of all outbound paths from the user node if paths, err := ops.TraversePaths(tx, ops.TraversalPlan{ @@ -108,9 +112,10 @@ func TestTraverseNodePaths(t *testing.T) { } func TestAzureEntityRoles(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZBaseHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if roles, err := azureanalysis.FetchEntityRoles(tx, harness.AZBaseHarness.User, 0, 0); err != nil { t.Fatal(err) @@ -121,9 +126,10 @@ func TestAzureEntityRoles(t *testing.T) { } func TestAzureEntityGroupMembership(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + 
testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZBaseHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if groupPaths, err := azureanalysis.FetchEntityGroupMembershipPaths(tx, harness.AZBaseHarness.User); err != nil { t.Fatal(err) @@ -134,9 +140,10 @@ func TestAzureEntityGroupMembership(t *testing.T) { } func TestAZMGApplicationReadWriteAll(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGApplicationReadWriteAllHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGApplicationReadWriteAllHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -191,9 +198,10 @@ func TestAZMGApplicationReadWriteAll(t *testing.T) { } func TestAZMGAppRoleManagementReadWriteAll(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGAppRoleManagementReadWriteAllHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGAppRoleManagementReadWriteAllHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -234,9 +242,10 @@ func TestAZMGAppRoleManagementReadWriteAll(t *testing.T) { } func 
TestAZMGDirectoryReadWriteAll(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGDirectoryReadWriteAllHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGDirectoryReadWriteAllHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -277,9 +286,10 @@ func TestAZMGDirectoryReadWriteAll(t *testing.T) { } func TestAZMGGroupReadWriteAll(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGGroupReadWriteAllHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGGroupReadWriteAllHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -320,9 +330,10 @@ func TestAZMGGroupReadWriteAll(t *testing.T) { } func TestAZMGGroupMemberReadWriteAll(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGGroupMemberReadWriteAllHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx 
graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGGroupMemberReadWriteAllHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -363,9 +374,10 @@ func TestAZMGGroupMemberReadWriteAll(t *testing.T) { } func TestAZMGRoleManagementReadWriteDirectory(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGRoleManagementReadWriteDirectoryHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGRoleManagementReadWriteDirectoryHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -430,9 +442,10 @@ func TestAZMGRoleManagementReadWriteDirectory(t *testing.T) { } func TestAZMGServicePrincipalEndpointReadWriteAll(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZMGServicePrincipalEndpointReadWriteAllHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { if outboundAbusableAppRoleAssignments, err := azureanalysis.FetchAbusableAppRoleAssignments(tx, harness.AZMGServicePrincipalEndpointReadWriteAllHarness.ServicePrincipal, graph.DirectionOutbound, 0, 0); err != nil { @@ -477,22 +490,23 @@ func TestAZMGServicePrincipalEndpointReadWriteAll(t *testing.T) { **********************/ func 
TestApplicationEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { appObjectID, err := harness.AZEntityPanelHarness.Application.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", appObjectID) - app, err := azureanalysis.ApplicationEntityDetails(context.Background(), testContext.GraphDB, appObjectID, false) + app, err := azureanalysis.ApplicationEntityDetails(context.Background(), testContext.Graph.Database, appObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.Application.Properties.Get(common.ObjectID.String()).Any(), app.Properties[common.ObjectID.String()]) assert.Equal(t, 0, app.InboundObjectControl) - app, err = azureanalysis.ApplicationEntityDetails(context.Background(), testContext.GraphDB, appObjectID, true) + app, err = azureanalysis.ApplicationEntityDetails(context.Background(), testContext.Graph.Database, appObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, app.InboundObjectControl) @@ -500,22 +514,23 @@ func TestApplicationEntityDetails(t *testing.T) { } func TestDeviceEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { deviceObjectID, err := 
harness.AZEntityPanelHarness.Device.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", deviceObjectID) - device, err := azureanalysis.DeviceEntityDetails(context.Background(), testContext.GraphDB, deviceObjectID, false) + device, err := azureanalysis.DeviceEntityDetails(context.Background(), testContext.Graph.Database, deviceObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.Device.Properties.Get(common.ObjectID.String()).Any(), device.Properties[common.ObjectID.String()]) assert.Equal(t, 0, device.InboundObjectControl) - device, err = azureanalysis.DeviceEntityDetails(context.Background(), testContext.GraphDB, deviceObjectID, true) + device, err = azureanalysis.DeviceEntityDetails(context.Background(), testContext.Graph.Database, deviceObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, device.InboundObjectControl) @@ -523,22 +538,23 @@ func TestDeviceEntityDetails(t *testing.T) { } func TestGroupEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { groupObjectID, err := harness.AZEntityPanelHarness.Group.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", groupObjectID) - group, err := azureanalysis.GroupEntityDetails(testContext.GraphDB, groupObjectID, false) + group, err := azureanalysis.GroupEntityDetails(testContext.Graph.Database, groupObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.Group.Properties.Get(common.ObjectID.String()).Any(), group.Properties[common.ObjectID.String()]) assert.Equal(t, 0, 
group.InboundObjectControl) - group, err = azureanalysis.GroupEntityDetails(testContext.GraphDB, groupObjectID, true) + group, err = azureanalysis.GroupEntityDetails(testContext.Graph.Database, groupObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, group.InboundObjectControl) @@ -546,22 +562,23 @@ func TestGroupEntityDetails(t *testing.T) { } func TestManagementGroupEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { groupObjectID, err := harness.AZEntityPanelHarness.ManagementGroup.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", groupObjectID) - group, err := azureanalysis.ManagementGroupEntityDetails(context.Background(), testContext.GraphDB, groupObjectID, false) + group, err := azureanalysis.ManagementGroupEntityDetails(context.Background(), testContext.Graph.Database, groupObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.ManagementGroup.Properties.Get(common.ObjectID.String()).Any(), group.Properties[common.ObjectID.String()]) assert.Equal(t, 0, group.InboundObjectControl) - group, err = azureanalysis.ManagementGroupEntityDetails(context.Background(), testContext.GraphDB, groupObjectID, true) + group, err = azureanalysis.ManagementGroupEntityDetails(context.Background(), testContext.Graph.Database, groupObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, group.InboundObjectControl) @@ -569,22 +586,23 @@ func TestManagementGroupEntityDetails(t *testing.T) { } func TestResourceGroupEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - 
testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { groupObjectID, err := harness.AZEntityPanelHarness.ResourceGroup.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", groupObjectID) - group, err := azureanalysis.ResourceGroupEntityDetails(context.Background(), testContext.GraphDB, groupObjectID, false) + group, err := azureanalysis.ResourceGroupEntityDetails(context.Background(), testContext.Graph.Database, groupObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.ResourceGroup.Properties.Get(common.ObjectID.String()).Any(), group.Properties[common.ObjectID.String()]) assert.Equal(t, 0, group.InboundObjectControl) - group, err = azureanalysis.ResourceGroupEntityDetails(context.Background(), testContext.GraphDB, groupObjectID, true) + group, err = azureanalysis.ResourceGroupEntityDetails(context.Background(), testContext.Graph.Database, groupObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, group.InboundObjectControl) @@ -592,22 +610,23 @@ func TestResourceGroupEntityDetails(t *testing.T) { } func TestKeyVaultEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { keyVaultObjectID, err := 
harness.AZEntityPanelHarness.KeyVault.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", keyVaultObjectID) - keyVault, err := azureanalysis.KeyVaultEntityDetails(context.Background(), testContext.GraphDB, keyVaultObjectID, false) + keyVault, err := azureanalysis.KeyVaultEntityDetails(context.Background(), testContext.Graph.Database, keyVaultObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.KeyVault.Properties.Get(common.ObjectID.String()).Any(), keyVault.Properties[common.ObjectID.String()]) assert.Equal(t, 0, keyVault.InboundObjectControl) - keyVault, err = azureanalysis.KeyVaultEntityDetails(context.Background(), testContext.GraphDB, keyVaultObjectID, true) + keyVault, err = azureanalysis.KeyVaultEntityDetails(context.Background(), testContext.Graph.Database, keyVaultObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, keyVault.InboundObjectControl) @@ -615,22 +634,23 @@ func TestKeyVaultEntityDetails(t *testing.T) { } func TestRoleEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { roleObjectID, err := harness.AZEntityPanelHarness.Role.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", roleObjectID) - role, err := azureanalysis.RoleEntityDetails(context.Background(), testContext.GraphDB, roleObjectID, false) + role, err := azureanalysis.RoleEntityDetails(context.Background(), testContext.Graph.Database, roleObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.Role.Properties.Get(common.ObjectID.String()).Any(), 
role.Properties[common.ObjectID.String()]) assert.Equal(t, 0, role.ActiveAssignments) - role, err = azureanalysis.RoleEntityDetails(context.Background(), testContext.GraphDB, roleObjectID, true) + role, err = azureanalysis.RoleEntityDetails(context.Background(), testContext.Graph.Database, roleObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, role.ActiveAssignments) @@ -638,22 +658,23 @@ func TestRoleEntityDetails(t *testing.T) { } func TestServicePrincipalEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { servicePrincipalObjectID, err := harness.AZEntityPanelHarness.ServicePrincipal.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", servicePrincipalObjectID) - servicePrincipal, err := azureanalysis.ServicePrincipalEntityDetails(context.Background(), testContext.GraphDB, servicePrincipalObjectID, false) + servicePrincipal, err := azureanalysis.ServicePrincipalEntityDetails(context.Background(), testContext.Graph.Database, servicePrincipalObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.ServicePrincipal.Properties.Get(common.ObjectID.String()).Any(), servicePrincipal.Properties[common.ObjectID.String()]) assert.Equal(t, 0, servicePrincipal.InboundObjectControl) - servicePrincipal, err = azureanalysis.ServicePrincipalEntityDetails(context.Background(), testContext.GraphDB, servicePrincipalObjectID, true) + servicePrincipal, err = azureanalysis.ServicePrincipalEntityDetails(context.Background(), testContext.Graph.Database, servicePrincipalObjectID, true) require.Nil(t, err) 
assert.NotEqual(t, 0, servicePrincipal.InboundObjectControl) @@ -661,22 +682,23 @@ func TestServicePrincipalEntityDetails(t *testing.T) { } func TestSubscriptionEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { subscriptionObjectID, err := harness.AZEntityPanelHarness.Subscription.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", subscriptionObjectID) - subscription, err := azureanalysis.SubscriptionEntityDetails(context.Background(), testContext.GraphDB, subscriptionObjectID, false) + subscription, err := azureanalysis.SubscriptionEntityDetails(context.Background(), testContext.Graph.Database, subscriptionObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.Subscription.Properties.Get(common.ObjectID.String()).Any(), subscription.Properties[common.ObjectID.String()]) assert.Equal(t, 0, subscription.InboundObjectControl) - subscription, err = azureanalysis.SubscriptionEntityDetails(context.Background(), testContext.GraphDB, subscriptionObjectID, true) + subscription, err = azureanalysis.SubscriptionEntityDetails(context.Background(), testContext.Graph.Database, subscriptionObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, subscription.InboundObjectControl) @@ -684,22 +706,23 @@ func TestSubscriptionEntityDetails(t *testing.T) { } func TestTenantEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) 
+ testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { tenantObjectID, err := harness.AZEntityPanelHarness.Tenant.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", tenantObjectID) - tenant, err := azureanalysis.TenantEntityDetails(testContext.GraphDB, tenantObjectID, false) + tenant, err := azureanalysis.TenantEntityDetails(testContext.Graph.Database, tenantObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.Tenant.Properties.Get(common.ObjectID.String()).Any(), tenant.Properties[common.ObjectID.String()]) assert.Equal(t, 0, tenant.InboundObjectControl) - tenant, err = azureanalysis.TenantEntityDetails(testContext.GraphDB, tenantObjectID, true) + tenant, err = azureanalysis.TenantEntityDetails(testContext.Graph.Database, tenantObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, tenant.InboundObjectControl) @@ -707,22 +730,23 @@ func TestTenantEntityDetails(t *testing.T) { } func TestUserEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { userObjectID, err := harness.AZEntityPanelHarness.User.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", userObjectID) - user, err := azureanalysis.UserEntityDetails(testContext.GraphDB, userObjectID, false) + user, err := azureanalysis.UserEntityDetails(testContext.Graph.Database, userObjectID, false) require.Nil(t, err) assert.Equal(t, 
harness.AZEntityPanelHarness.User.Properties.Get(common.ObjectID.String()).Any(), user.Properties[common.ObjectID.String()]) assert.Equal(t, 0, user.OutboundObjectControl) - user, err = azureanalysis.UserEntityDetails(testContext.GraphDB, userObjectID, true) + user, err = azureanalysis.UserEntityDetails(testContext.Graph.Database, userObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, user.OutboundObjectControl) @@ -730,22 +754,23 @@ func TestUserEntityDetails(t *testing.T) { } func TestVMEntityDetails(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZEntityPanelHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { vmObjectID, err := harness.AZEntityPanelHarness.VM.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) assert.NotEqual(t, "", vmObjectID) - vm, err := azureanalysis.VMEntityDetails(context.Background(), testContext.GraphDB, vmObjectID, false) + vm, err := azureanalysis.VMEntityDetails(context.Background(), testContext.Graph.Database, vmObjectID, false) require.Nil(t, err) assert.Equal(t, harness.AZEntityPanelHarness.VM.Properties.Get(common.ObjectID.String()).Any(), vm.Properties[common.ObjectID.String()]) assert.Equal(t, 0, vm.InboundObjectControl) - vm, err = azureanalysis.VMEntityDetails(context.Background(), testContext.GraphDB, vmObjectID, true) + vm, err = azureanalysis.VMEntityDetails(context.Background(), testContext.Graph.Database, vmObjectID, true) require.Nil(t, err) assert.NotEqual(t, 0, vm.InboundObjectControl) @@ -753,10 +778,11 @@ func TestVMEntityDetails(t *testing.T) { } func TestFetchInboundEntityObjectControlPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + 
testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZInboundControlHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { paths, err := azureanalysis.FetchInboundEntityObjectControlPaths(tx, harness.AZInboundControlHarness.ControlledAZUser, graph.DirectionInbound) require.Nil(t, err) @@ -774,10 +800,11 @@ func TestFetchInboundEntityObjectControlPaths(t *testing.T) { } func TestFetchInboundEntityObjectControllers(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) - testContext.ReadTransactionTest(func(harness *integration.HarnessDetails) { + testContext.ReadTransactionTestWithSetup(func(harness *integration.HarnessDetails) error { harness.AZInboundControlHarness.Setup(testContext) + return nil }, func(harness integration.HarnessDetails, tx graph.Transaction) { control, err := azureanalysis.FetchInboundEntityObjectControllers(tx, harness.AZInboundControlHarness.ControlledAZUser, graph.DirectionInbound, 0, 0) require.Nil(t, err) diff --git a/cmd/api/src/analysis/membership_integration_test.go b/cmd/api/src/analysis/membership_integration_test.go index 9fb27eb40b..a80fdf2fe4 100644 --- a/cmd/api/src/analysis/membership_integration_test.go +++ b/cmd/api/src/analysis/membership_integration_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 //go:build integration @@ -21,22 +21,25 @@ package analysis_test import ( "context" + schema "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/src/test" "testing" - "github.com/specterops/bloodhound/src/test/integration" - "github.com/stretchr/testify/require" analysis "github.com/specterops/bloodhound/analysis/ad" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/query" "github.com/specterops/bloodhound/graphschema/ad" + "github.com/specterops/bloodhound/src/test/integration" + "github.com/stretchr/testify/require" ) func TestRealizeNodeKindDuplexMap(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.RootADHarness.Setup(testContext) harness.TrustDCSync.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { var ( domainNode = testContext.FindNode(query.Equals(query.NodeProperty("name"), "DomainA")) impactMap, impactErr = analysis.FetchPathMembers(context.Background(), db, domainNode.ID, graph.DirectionInbound) @@ -51,17 +54,16 @@ func TestRealizeNodeKindDuplexMap(t *testing.T) { require.Equal(t, 2, int(impactKindMap.Get(ad.Group).Cardinality())) 
require.Equal(t, 3, int(impactKindMap.Get(ad.User).Cardinality())) require.Equal(t, 1, int(impactKindMap.Get(ad.GPO).Cardinality())) - - return nil }) } func TestAnalyzeExposure(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.RootADHarness.Setup(testContext) harness.TrustDCSync.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { var ( domainNode = testContext.FindNode(query.Equals(query.NodeProperty("name"), "DomainA")) impactMap, err = analysis.FetchPathMembers(context.Background(), db, domainNode.ID, graph.DirectionInbound) @@ -69,24 +71,23 @@ func TestAnalyzeExposure(t *testing.T) { require.Nil(t, err) require.Equalf(t, 9, int(impactMap.Cardinality()), "Failed to collect expected nodes. 
Saw IDs: %+v", impactMap.Slice()) - - return nil }) } func TestResolveAllGroupMemberships(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.RDP.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { memberships, err := analysis.ResolveAllGroupMemberships(context.Background(), db) + test.RequireNilErr(t, err) + require.Equal(t, 3, int(memberships.Cardinality(harness.RDP.DomainGroupA.ID.Uint32()).Cardinality())) require.Equal(t, 2, int(memberships.Cardinality(harness.RDP.DomainGroupB.ID.Uint32()).Cardinality())) require.Equal(t, 1, int(memberships.Cardinality(harness.RDP.DomainGroupC.ID.Uint32()).Cardinality())) require.Equal(t, 1, int(memberships.Cardinality(harness.RDP.DomainGroupD.ID.Uint32()).Cardinality())) require.Equal(t, 2, int(memberships.Cardinality(harness.RDP.DomainGroupE.ID.Uint32()).Cardinality())) - - return err }) } diff --git a/cmd/api/src/analysis/post_integration_test.go b/cmd/api/src/analysis/post_integration_test.go index 575da00e04..cd7cc543cf 100644 --- a/cmd/api/src/analysis/post_integration_test.go +++ b/cmd/api/src/analysis/post_integration_test.go @@ -22,6 +22,8 @@ package analysis_test import ( "context" ad2 "github.com/specterops/bloodhound/analysis/ad" + schema "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/src/test" "testing" "github.com/specterops/bloodhound/analysis" @@ -48,54 +50,44 @@ func FetchNumHarnessNodes(db graph.Database) (int64, error) { func TestClearOrphanedNodes(t *testing.T) { const numNodesToCreate = 1000 - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTest(func(harness 
integration.HarnessDetails, db graph.Database) error { - if numHarnessNodes, err := FetchNumHarnessNodes(db); err != nil { - return err - } else { - if err := db.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - for numCreated := 0; numCreated < numNodesToCreate; numCreated++ { - if _, err := tx.CreateNode(graph.NewProperties(), ad.Entity); err != nil { - return err - } - } - - return nil - }); err != nil { - return err - } - - if numNodesAfterCreation, err := FetchNumHarnessNodes(db); err != nil { - return err - } else { - require.Equal(t, numHarnessNodes+numNodesToCreate, numNodesAfterCreation) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) { + numHarnessNodes, err := FetchNumHarnessNodes(db) + test.RequireNilErr(t, err) - if err := analysis.ClearOrphanedNodes(context.Background(), db); err != nil { + test.RequireNilErr(t, db.WriteTransaction(context.Background(), func(tx graph.Transaction) error { + for numCreated := 0; numCreated < numNodesToCreate; numCreated++ { + if _, err := tx.CreateNode(graph.NewProperties(), ad.Entity); err != nil { return err - } else if numNodesAfterDeletion, err := FetchNumHarnessNodes(db); err != nil { - return err - } else { - require.Equal(t, numHarnessNodes, numNodesAfterDeletion) } } - } - return nil + return nil + })) + + numNodesAfterCreation, err := FetchNumHarnessNodes(db) + test.RequireNilErr(t, err) + + require.Equal(t, numHarnessNodes+numNodesToCreate, numNodesAfterCreation) + test.RequireNilErr(t, analysis.ClearOrphanedNodes(context.Background(), db)) + + numNodesAfterDeletion, err := FetchNumHarnessNodes(db) + test.RequireNilErr(t, err) + require.Equal(t, numHarnessNodes, numNodesAfterDeletion) }) } func TestCrossProduct(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) { + 
testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTestWithSetup(func(harness *integration.HarnessDetails) error { harness.ShortcutHarness.Setup(testContext) - }, func(harness integration.HarnessDetails, db graph.Database) error { + return nil + }, func(harness integration.HarnessDetails, db graph.Database) { firstSet := []*graph.Node{testContext.Harness.ShortcutHarness.Group1} secondSet := []*graph.Node{testContext.Harness.ShortcutHarness.Group2} groupExpansions, err := ad2.ExpandAllRDPLocalGroups(context.Background(), db) require.Nil(t, err) results := ad2.CalculateCrossProductNodeSets(groupExpansions, firstSet, secondSet) require.True(t, results.Contains(harness.ShortcutHarness.Group3.ID.Uint32())) - - return nil }) } diff --git a/cmd/api/src/api/middleware/auth.go b/cmd/api/src/api/middleware/auth.go index 4c71f02768..99ba977a4f 100644 --- a/cmd/api/src/api/middleware/auth.go +++ b/cmd/api/src/api/middleware/auth.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package middleware @@ -99,14 +99,30 @@ func AuthMiddleware(authenticator api.Authenticator) mux.MiddlewareFunc { } } -// PermissionsCheck is a middleware func generator that returns a http.Handler which closes around a list of +// PermissionsCheckAll is a middleware func generator that returns a http.Handler which closes around a list of // permissions that an actor must have in the request auth context to access the wrapped http.Handler. -func PermissionsCheck(authorizer auth.Authorizer, permissions ...model.Permission) mux.MiddlewareFunc { +func PermissionsCheckAll(authorizer auth.Authorizer, permissions ...model.Permission) mux.MiddlewareFunc { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + if bhCtx := ctx.FromRequest(request); !bhCtx.AuthCtx.Authenticated() { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnauthorized, "not authenticated", request), response) + } else if !authorizer.AllowsAllPermissions(bhCtx.AuthCtx, permissions) { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, "not authorized", request), response) + } else { + next.ServeHTTP(response, request) + } + }) + } +} + +// PermissionsCheckAtLeastOne is a middleware func generator that returns a http.Handler which closes around a list of +// permissions that an actor must have at least one in the request auth context to access the wrapped http.Handler. 
+func PermissionsCheckAtLeastOne(authorizer auth.Authorizer, permissions ...model.Permission) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { if bhCtx := ctx.FromRequest(request); !bhCtx.AuthCtx.Authenticated() { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnauthorized, "not authenticated", request), response) - } else if !authorizer.AllowsPermissions(bhCtx.AuthCtx, permissions) { + } else if !authorizer.AllowsAtLeastOnePermission(bhCtx.AuthCtx, permissions) { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, "not authorized", request), response) } else { next.ServeHTTP(response, request) diff --git a/cmd/api/src/api/middleware/auth_test.go b/cmd/api/src/api/middleware/auth_test.go index 8c31ede1c8..4bb73cbca9 100644 --- a/cmd/api/src/api/middleware/auth_test.go +++ b/cmd/api/src/api/middleware/auth_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package middleware @@ -21,6 +21,7 @@ import ( "testing" "time" + "github.com/specterops/bloodhound/headers" "github.com/specterops/bloodhound/src/api" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/ctx" @@ -28,11 +29,14 @@ import ( "github.com/specterops/bloodhound/src/test/must" "github.com/specterops/bloodhound/src/utils/test" "github.com/stretchr/testify/require" - "github.com/specterops/bloodhound/headers" ) -func permissionsCheckHandler(internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { - return PermissionsCheck(auth.NewAuthorizer(), permissions...)(internalHandler) +func permissionsCheckAllHandler(internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { + return PermissionsCheckAll(auth.NewAuthorizer(), permissions...)(internalHandler) +} + +func permissionsCheckAtLeastOneHandler(internalHandler http.HandlerFunc, permissions ...model.Permission) http.Handler { + return PermissionsCheckAtLeastOne(auth.NewAuthorizer(), permissions...)(internalHandler) } func Test_parseAuthorizationHeader(t *testing.T) { @@ -52,7 +56,7 @@ func Test_parseAuthorizationHeader(t *testing.T) { require.Nil(t, err) } -func TestPermissionsCheck(t *testing.T) { +func TestPermissionsCheckAll(t *testing.T) { var ( handlerReturn200 = func(response http.ResponseWriter, request *http.Request) { response.WriteHeader(http.StatusOK) @@ -63,7 +67,7 @@ func TestPermissionsCheck(t *testing.T) { WithURL("http//example.com"). WithHeader(headers.RequestID.String(), "requestID"). WithContext(&ctx.Context{}). - OnHandler(permissionsCheckHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAllHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). ResponseStatusCode(http.StatusUnauthorized) @@ -83,7 +87,121 @@ func TestPermissionsCheck(t *testing.T) { Session: model.UserSession{}, }, }). 
- OnHandler(permissionsCheckHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAllHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + Require(). + ResponseStatusCode(http.StatusForbidden) + + test.Request(t). + WithURL("http//example.com"). + WithHeader(headers.RequestID.String(), "requestID"). + WithContext(&ctx.Context{ + AuthCtx: auth.Context{ + PermissionOverrides: auth.PermissionOverrides{}, + Owner: model.User{ + Roles: model.Roles{ + { + Name: "Big Boy", + Description: "The big boy.", + Permissions: auth.Permissions().All(), + }, + }, + }, + Session: model.UserSession{}, + }, + }). + OnHandler(permissionsCheckAllHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + Require(). + ResponseStatusCode(http.StatusOK) +} + +func TestPermissionsCheckAtLeastOne(t *testing.T) { + var ( + handlerReturn200 = func(response http.ResponseWriter, request *http.Request) { + response.WriteHeader(http.StatusOK) + } + ) + + test.Request(t). + WithURL("http//example.com"). + WithContext(&ctx.Context{ + AuthCtx: auth.Context{ + PermissionOverrides: auth.PermissionOverrides{}, + Owner: model.User{ + Roles: model.Roles{ + { + Name: "Big Boy", + Description: "The big boy.", + Permissions: model.Permissions{auth.Permissions().AuthManageSelf}, + }, + }, + }, + Session: model.UserSession{}, + }, + }). + OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + Require(). + ResponseStatusCode(http.StatusOK) + + test.Request(t). + WithURL("http//example.com"). + WithContext(&ctx.Context{ + AuthCtx: auth.Context{ + PermissionOverrides: auth.PermissionOverrides{}, + Owner: model.User{ + Roles: model.Roles{ + { + Name: "Big Boy", + Description: "The big boy.", + Permissions: model.Permissions{auth.Permissions().AuthManageSelf, auth.Permissions().GraphDBRead}, + }, + }, + }, + Session: model.UserSession{}, + }, + }). 
+ OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + Require(). + ResponseStatusCode(http.StatusOK) + + test.Request(t). + WithURL("http//example.com"). + WithContext(&ctx.Context{ + AuthCtx: auth.Context{ + PermissionOverrides: auth.PermissionOverrides{}, + Owner: model.User{ + Roles: model.Roles{ + { + Name: "Big Boy", + Description: "The big boy.", + Permissions: model.Permissions{auth.Permissions().AuthManageSelf, auth.Permissions().GraphDBRead}, + }, + }, + }, + Session: model.UserSession{}, + }, + }). + OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().GraphDBRead)). + Require(). + ResponseStatusCode(http.StatusOK) + + test.Request(t). + WithURL("http//example.com"). + WithContext(&ctx.Context{ + AuthCtx: auth.Context{ + PermissionOverrides: auth.PermissionOverrides{}, + Owner: model.User{ + Roles: model.Roles{ + { + Name: "Big Boy", + Description: "The big boy.", + Permissions: model.Permissions{auth.Permissions().AuthManageSelf, auth.Permissions().GraphDBRead}, + }, + }, + }, + Session: model.UserSession{}, + }, + }). + OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().GraphDBWrite)). Require(). ResponseStatusCode(http.StatusForbidden) @@ -105,7 +223,7 @@ func TestPermissionsCheck(t *testing.T) { Session: model.UserSession{}, }, }). - OnHandler(permissionsCheckHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). + OnHandler(permissionsCheckAtLeastOneHandler(handlerReturn200, auth.Permissions().AuthManageSelf)). Require(). 
ResponseStatusCode(http.StatusOK) } diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index b81f4f2757..4a5e9e3611 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -17,6 +17,8 @@ package registration import ( + "net/http" + "github.com/specterops/bloodhound/cache" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/src/api" @@ -28,19 +30,19 @@ import ( "github.com/specterops/bloodhound/src/config" "github.com/specterops/bloodhound/src/daemons/datapipe" "github.com/specterops/bloodhound/src/database" - "net/http" + "github.com/specterops/bloodhound/src/queries" ) func RegisterFossGlobalMiddleware(routerInst *router.Router, cfg config.Configuration, identityResolver auth.IdentityResolver, authenticator api.Authenticator) { - // Set up logging - if cfg.EnableAPILogging { - routerInst.UsePrerouting(middleware.LoggingMiddleware(cfg, identityResolver)) - } - // Set up the middleware stack routerInst.UsePrerouting(middleware.ContextMiddleware) routerInst.UsePrerouting(middleware.CORSMiddleware()) + // Set up logging. 
This must be done after ContextMiddleware is initialized so the context can be accessed in the log logic + if cfg.EnableAPILogging { + routerInst.UsePrerouting(middleware.LoggingMiddleware(cfg, identityResolver)) + } + routerInst.UsePostrouting( middleware.PanicHandler, middleware.AuthMiddleware(authenticator), @@ -49,12 +51,16 @@ func RegisterFossGlobalMiddleware(routerInst *router.Router, cfg config.Configur } func RegisterFossRoutes( - routerInst *router.Router, cfg config.Configuration, db database.Database, graphDB graph.Database, - apiCache cache.Cache, graphQueryCache cache.Cache, collectorManifests config.CollectorManifests, - authenticator api.Authenticator, taskNotifier datapipe.Tasker, + routerInst *router.Router, + cfg config.Configuration, + rdms *database.BloodhoundDB, + graphDB *graph.DatabaseSwitch, + graphQuery queries.Graph, + apiCache cache.Cache, + collectorManifests config.CollectorManifests, + authenticator api.Authenticator, + taskNotifier datapipe.Tasker, ) { - var resources = v2.NewResources(db, graphDB, cfg, apiCache, graphQueryCache, collectorManifests, taskNotifier) - router.With(middleware.DefaultRateLimitMiddleware, // Health Endpoint routerInst.GET("/health", func(response http.ResponseWriter, _ *http.Request) { @@ -70,5 +76,6 @@ func RegisterFossRoutes( routerInst.PathPrefix("/ui", static.Handler()), ) + var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, taskNotifier) NewV2API(cfg, resources, routerInst, authenticator) } diff --git a/cmd/api/src/api/router/router.go b/cmd/api/src/api/router/router.go index d1f8b6d638..bf2c47b4d5 100644 --- a/cmd/api/src/api/router/router.go +++ b/cmd/api/src/api/router/router.go @@ -68,8 +68,15 @@ func (s *Route) RequireAuth() *Route { return s.RequirePermissions() } +// Ensure that the requestor has all of the listed permissions func (s *Route) RequirePermissions(permissions ...model.Permission) *Route { - 
s.handler.Use(middleware.PermissionsCheck(s.authorizer, permissions...)) + s.handler.Use(middleware.PermissionsCheckAll(s.authorizer, permissions...)) + return s +} + +// Ensure that the requestor has at least one of the listed permissions +func (s *Route) RequireAtLeastOnePermission(permissions ...model.Permission) *Route { + s.handler.Use(middleware.PermissionsCheckAtLeastOne(s.authorizer, permissions...)) return s } diff --git a/cmd/api/src/api/tools/dbswitch.go b/cmd/api/src/api/tools/dbswitch.go new file mode 100644 index 0000000000..97e9972d94 --- /dev/null +++ b/cmd/api/src/api/tools/dbswitch.go @@ -0,0 +1,95 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "errors" + "github.com/jackc/pgx/v5" + "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/config" +) + +func newPostgresqlConnection(ctx context.Context, cfg config.Configuration) (*pgx.Conn, error) { + if pgCfg, err := pgx.ParseConfig(cfg.Database.PostgreSQLConnectionString()); err != nil { + return nil, err + } else { + return pgx.ConnectConfig(ctx, pgCfg) + } +} + +func HasGraphDriverSet(ctx context.Context, pgxConn *pgx.Conn) (bool, error) { + var ( + exists bool + row = pgxConn.QueryRow(ctx, `select exists(select * from database_switch limit 1);`) + ) + + return exists, row.Scan(&exists) +} + +func GetGraphDriver(ctx context.Context, pgxConn *pgx.Conn) (string, error) { + var ( + driverName string + row = pgxConn.QueryRow(ctx, `select driver from database_switch limit 1;`) + ) + + return driverName, row.Scan(&driverName) +} + +func SetGraphDriver(ctx context.Context, cfg config.Configuration, driverName string) error { + if pgxConn, err := newPostgresqlConnection(ctx, cfg); err != nil { + return err + } else { + defer pgxConn.Close(ctx) + + if hasDriver, err := HasGraphDriverSet(ctx, pgxConn); err != nil { + return err + } else if hasDriver { + _, err := pgxConn.Exec(ctx, `update database_switch set driver = $1;`, driverName) + return err + } else { + _, err := pgxConn.Exec(ctx, `insert into database_switch (driver) values ($1);`, driverName) + return err + } + } +} + +func LookupGraphDriver(ctx context.Context, cfg config.Configuration) (string, error) { + driverName := cfg.GraphDriver + + if pgxConn, err := newPostgresqlConnection(ctx, cfg); err != nil { + return "", err + } else { + defer pgxConn.Close(ctx) + + if _, err := pgxConn.Exec(ctx, `create table if not exists database_switch (driver text not null, primary key(driver));`); err != nil { + return "", err + } + + if setDriverName, err := GetGraphDriver(ctx, pgxConn); err != nil { + if 
errors.Is(err, pgx.ErrNoRows) { + log.Infof("No database driver has been set for migration, using: %s", driverName) + } else { + return "", err + } + } else { + driverName = setDriverName + } + } + + return driverName, nil +} diff --git a/cmd/api/src/api/tools/flag.go b/cmd/api/src/api/tools/flag.go index a070bd35cc..f3a6558d77 100644 --- a/cmd/api/src/api/tools/flag.go +++ b/cmd/api/src/api/tools/flag.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package tools diff --git a/cmd/api/src/api/tools/pg.go b/cmd/api/src/api/tools/pg.go new file mode 100644 index 0000000000..3a0a4da0fe --- /dev/null +++ b/cmd/api/src/api/tools/pg.go @@ -0,0 +1,349 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package tools + +import ( + "context" + "fmt" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" + "github.com/specterops/bloodhound/dawgs" + "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/specterops/bloodhound/dawgs/drivers/pg" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/dawgs/util/size" + "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/api" + "github.com/specterops/bloodhound/src/config" + "net/http" + "sync" +) + +type MigratorState string + +const ( + stateIdle MigratorState = "idle" + stateMigrating MigratorState = "migrating" + stateCanceling MigratorState = "canceling" +) + +func migrateTypes(ctx context.Context, neoDB, pgDB graph.Database) error { + defer log.LogAndMeasure(log.LevelInfo, "Migrating kinds from Neo4j to PostgreSQL")() + + var ( + neoNodeKinds graph.Kinds + neoEdgeKinds graph.Kinds + ) + + if err := neoDB.ReadTransaction(ctx, func(tx graph.Transaction) error { + var ( + nextKindStr string + result = tx.Raw("call db.labels();", nil) + ) + + for result.Next() { + if err := result.Scan(&nextKindStr); err != nil { + return err + } + + neoNodeKinds = append(neoNodeKinds, graph.StringKind(nextKindStr)) + } + + if err := result.Error(); err != nil { + return err + } + + result = tx.Raw("call db.relationshipTypes();", nil) + + for result.Next() { + if err := result.Scan(&nextKindStr); err != nil { + return err + } + + neoEdgeKinds = append(neoEdgeKinds, graph.StringKind(nextKindStr)) + } + + return nil + }); err != nil { + return err + } + + return pgDB.WriteTransaction(ctx, func(tx graph.Transaction) error { + _, err := pgDB.(*pg.Driver).KindMapper().AssertKinds(tx, append(neoNodeKinds, neoEdgeKinds...)) + return err + }) +} + +func convertNeo4jProperties(properties *graph.Properties) error { + for key, propertyValue := range properties.Map { + switch typedPropertyValue := propertyValue.(type) { + case 
dbtype.Date: + properties.Map[key] = typedPropertyValue.Time() + + case dbtype.Duration: + return fmt.Errorf("unsupported conversion") + + case dbtype.Time: + properties.Map[key] = typedPropertyValue.Time() + + case dbtype.LocalTime: + properties.Map[key] = typedPropertyValue.Time() + + case dbtype.LocalDateTime: + properties.Map[key] = typedPropertyValue.Time() + } + } + + return nil +} + +func migrateNodes(ctx context.Context, neoDB, pgDB graph.Database) (map[graph.ID]graph.ID, error) { + defer log.LogAndMeasure(log.LevelInfo, "Migrating nodes from Neo4j to PostgreSQL")() + + var ( + // Start at 2 and assume that the first node of the graph is the graph schema migration information + nextNodeID = graph.ID(2) + nodeIDMappings = map[graph.ID]graph.ID{} + ) + + if err := neoDB.ReadTransaction(ctx, func(tx graph.Transaction) error { + return tx.Nodes().Fetch(func(cursor graph.Cursor[*graph.Node]) error { + if err := pgDB.BatchOperation(ctx, func(tx graph.Batch) error { + for next := range cursor.Chan() { + if err := convertNeo4jProperties(next.Properties); err != nil { + return err + } + + if err := tx.CreateNode(graph.NewNode(nextNodeID, next.Properties, next.Kinds...)); err != nil { + return err + } else { + nodeIDMappings[next.ID] = nextNodeID + nextNodeID++ + } + } + + return nil + }); err != nil { + return err + } + + return cursor.Error() + }) + }); err != nil { + return nil, err + } + + return nodeIDMappings, pgDB.Run(ctx, fmt.Sprintf(`alter sequence node_id_seq restart with %d`, nextNodeID), nil) +} + +func migrateEdges(ctx context.Context, neoDB, pgDB graph.Database, nodeIDMappings map[graph.ID]graph.ID) error { + defer log.LogAndMeasure(log.LevelInfo, "Migrating edges from Neo4j to PostgreSQL")() + + return neoDB.ReadTransaction(ctx, func(tx graph.Transaction) error { + return tx.Relationships().Fetch(func(cursor graph.Cursor[*graph.Relationship]) error { + if err := pgDB.BatchOperation(ctx, func(tx graph.Batch) error { + for next := range cursor.Chan() { + 
var ( + pgStartID = nodeIDMappings[next.StartID] + pgEndID = nodeIDMappings[next.EndID] + ) + + if err := convertNeo4jProperties(next.Properties); err != nil { + return err + } + + if err := tx.CreateRelationship(&graph.Relationship{ + StartID: pgStartID, + EndID: pgEndID, + Kind: next.Kind, + Properties: next.Properties, + }); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + return cursor.Error() + }) + }) +} + +type PGMigrator struct { + graphSchema graph.Schema + graphDBSwitch *graph.DatabaseSwitch + serverCtx context.Context + migrationCancelFunc func() + state MigratorState + lock *sync.Mutex + cfg config.Configuration +} + +func NewPGMigrator(serverCtx context.Context, cfg config.Configuration, graphSchema graph.Schema, graphDBSwitch *graph.DatabaseSwitch) *PGMigrator { + return &PGMigrator{ + graphSchema: graphSchema, + graphDBSwitch: graphDBSwitch, + serverCtx: serverCtx, + state: stateIdle, + lock: &sync.Mutex{}, + cfg: cfg, + } +} + +func (s *PGMigrator) advanceState(next MigratorState, validTransitions ...MigratorState) error { + s.lock.Lock() + defer s.lock.Unlock() + + isValid := false + + for _, validTransition := range validTransitions { + if s.state == validTransition { + isValid = true + break + } + } + + if !isValid { + return fmt.Errorf("migrator state is %s but expected one of: %v", s.state, validTransitions) + } + + s.state = next + return nil +} + +func (s *PGMigrator) SwitchPostgreSQL(response http.ResponseWriter, request *http.Request) { + if pgDB, err := dawgs.Open(s.serverCtx, pg.DriverName, dawgs.Config{ + TraversalMemoryLimit: size.Gibibyte, + DriverCfg: s.cfg.Database.PostgreSQLConnectionString(), + }); err != nil { + api.WriteJSONResponse(request.Context(), map[string]any{ + "error": fmt.Errorf("failed connecting to PostgreSQL: %w", err), + }, http.StatusInternalServerError, response) + } else if err := SetGraphDriver(request.Context(), s.cfg, pg.DriverName); err != nil { + 
api.WriteJSONResponse(request.Context(), map[string]any{ + "error": fmt.Errorf("failed updating graph database driver preferences: %w", err), + }, http.StatusInternalServerError, response) + } else { + s.graphDBSwitch.Switch(pgDB) + response.WriteHeader(http.StatusOK) + + log.Infof("Updated default graph driver to PostgreSQL") + } +} + +func (s *PGMigrator) SwitchNeo4j(response http.ResponseWriter, request *http.Request) { + if neo4jDB, err := dawgs.Open(s.serverCtx, neo4j.DriverName, dawgs.Config{ + TraversalMemoryLimit: size.Gibibyte, + DriverCfg: s.cfg.Neo4J.Neo4jConnectionString(), + }); err != nil { + api.WriteJSONResponse(request.Context(), map[string]any{ + "error": fmt.Errorf("failed connecting to Neo4j: %w", err), + }, http.StatusInternalServerError, response) + } else if err := SetGraphDriver(request.Context(), s.cfg, neo4j.DriverName); err != nil { + api.WriteJSONResponse(request.Context(), map[string]any{ + "error": fmt.Errorf("failed updating graph database driver preferences: %w", err), + }, http.StatusInternalServerError, response) + } else { + s.graphDBSwitch.Switch(neo4jDB) + response.WriteHeader(http.StatusOK) + + log.Infof("Updated default graph driver to Neo4j") + } +} + +func (s *PGMigrator) startMigration() error { + if err := s.advanceState(stateMigrating, stateIdle); err != nil { + return fmt.Errorf("database migration state error: %w", err) + } else if neo4jDB, err := dawgs.Open(s.serverCtx, neo4j.DriverName, dawgs.Config{ + TraversalMemoryLimit: size.Gibibyte, + DriverCfg: s.cfg.Neo4J.Neo4jConnectionString(), + }); err != nil { + return fmt.Errorf("failed connecting to Neo4j: %w", err) + } else if pgDB, err := dawgs.Open(s.serverCtx, pg.DriverName, dawgs.Config{ + TraversalMemoryLimit: size.Gibibyte, + DriverCfg: s.cfg.Database.PostgreSQLConnectionString(), + }); err != nil { + return fmt.Errorf("failed connecting to PostgreSQL: %w", err) + } else { + log.Infof("Dispatching live migration from Neo4j to PostgreSQL") + + migrationCtx, 
migrationCancelFunc := context.WithCancel(s.serverCtx) + s.migrationCancelFunc = migrationCancelFunc + + go func(ctx context.Context) { + defer migrationCancelFunc() + + log.Infof("Starting live migration from Neo4j to PostgreSQL") + + if err := pgDB.AssertSchema(ctx, s.graphSchema); err != nil { + log.Errorf("Unable to assert graph schema in PostgreSQL: %v", err) + } else if err := migrateTypes(ctx, neo4jDB, pgDB); err != nil { + log.Errorf("Unable to migrate Neo4j kinds to PostgreSQL: %v", err) + } else if nodeIDMappings, err := migrateNodes(ctx, neo4jDB, pgDB); err != nil { + log.Errorf("Failed importing nodes into PostgreSQL: %v", err) + } else if err := migrateEdges(ctx, neo4jDB, pgDB, nodeIDMappings); err != nil { + log.Errorf("Failed importing edges into PostgreSQL: %v", err) + } else { + log.Infof("Migration to PostgreSQL completed successfully") + } + + if err := s.advanceState(stateIdle, stateMigrating, stateCanceling); err != nil { + log.Errorf("Database migration state management error: %v", err) + } + }(migrationCtx) + } + + return nil +} + +func (s *PGMigrator) MigrationStart(response http.ResponseWriter, request *http.Request) { + if err := s.startMigration(); err != nil { + api.WriteJSONResponse(request.Context(), map[string]any{ + "error": err.Error(), + }, http.StatusInternalServerError, response) + } else { + response.WriteHeader(http.StatusAccepted) + } +} + +func (s *PGMigrator) cancelMigration() error { + if err := s.advanceState(stateCanceling, stateMigrating); err != nil { + return err + } + + s.migrationCancelFunc() + + return nil +} + +func (s *PGMigrator) MigrationCancel(response http.ResponseWriter, request *http.Request) { + if err := s.cancelMigration(); err != nil { + api.WriteJSONResponse(request.Context(), map[string]any{ + "error": err.Error(), + }, http.StatusInternalServerError, response) + } else { + response.WriteHeader(http.StatusAccepted) + } +} + +func (s *PGMigrator) MigrationStatus(response http.ResponseWriter, request 
*http.Request) { + api.WriteJSONResponse(request.Context(), map[string]any{ + "state": s.state, + }, http.StatusOK, response) +} diff --git a/cmd/api/src/api/v2/apiclient/apiclient.go b/cmd/api/src/api/v2/apiclient/apiclient.go index e605b1d97b..882de551ae 100644 --- a/cmd/api/src/api/v2/apiclient/apiclient.go +++ b/cmd/api/src/api/v2/apiclient/apiclient.go @@ -90,7 +90,7 @@ func (s Client) Request(method, path string, params url.Values, body any, header request.Header = header[0] } - // Execute the Request and hand the response back to the user + // query the Request and hand the response back to the user const ( sleepInterval = time.Second * 5 maxSleep = sleepInterval * 5 @@ -149,6 +149,6 @@ func (s Client) Raw(request *http.Request) (*http.Response, error) { } } - // Execute the Request and hand the response back to the user + // query the Request and hand the response back to the user return s.Http.Do(request) } diff --git a/cmd/api/src/api/v2/apiclient/flags.go b/cmd/api/src/api/v2/apiclient/flags.go new file mode 100644 index 0000000000..f85ce62dd4 --- /dev/null +++ b/cmd/api/src/api/v2/apiclient/flags.go @@ -0,0 +1,72 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package apiclient + +import ( + "fmt" + "github.com/specterops/bloodhound/src/api" + "github.com/specterops/bloodhound/src/model/appcfg" + "net/http" +) + +func (s Client) GetFeatureFlags() ([]appcfg.FeatureFlag, error) { + var featureFlags []appcfg.FeatureFlag + + if response, err := s.Request(http.MethodGet, "/api/v2/features", nil, nil); err != nil { + return nil, err + } else { + defer response.Body.Close() + + if api.IsErrorResponse(response) { + return nil, ReadAPIError(response) + } + + return featureFlags, api.ReadAPIV2ResponsePayload(&featureFlags, response) + } +} + +func (s Client) GetFeatureFlag(key string) (appcfg.FeatureFlag, error) { + if flags, err := s.GetFeatureFlags(); err != nil { + return appcfg.FeatureFlag{}, err + } else { + for _, flag := range flags { + if flag.Key == key { + return flag, nil + } + } + } + + return appcfg.FeatureFlag{}, fmt.Errorf("flag with key %s not found", key) +} + +func (s Client) ToggleFeatureFlag(key string) error { + var result appcfg.Parameter + + if flag, err := s.GetFeatureFlag(key); err != nil { + return err + } else if response, err := s.Request(http.MethodPut, fmt.Sprintf("/api/v2/features/%d/toggle", flag.ID), nil, nil); err != nil { + return err + } else { + defer response.Body.Close() + + if api.IsErrorResponse(response) { + return ReadAPIError(response) + } + + return api.ReadAPIV2ResponsePayload(&result, response) + } +} diff --git a/cmd/api/src/api/v2/app_config_integration_test.go b/cmd/api/src/api/v2/app_config_integration_test.go index 637c263899..1c7c595077 100644 --- a/cmd/api/src/api/v2/app_config_integration_test.go +++ b/cmd/api/src/api/v2/app_config_integration_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 //go:build serial_integration @@ -22,11 +22,11 @@ package v2_test import ( "testing" + "github.com/specterops/bloodhound/dawgs/drivers/neo4j" v2 "github.com/specterops/bloodhound/src/api/v2" "github.com/specterops/bloodhound/src/api/v2/integration" "github.com/specterops/bloodhound/src/model/appcfg" "github.com/stretchr/testify/require" - "github.com/specterops/bloodhound/dawgs/drivers/neo4j" ) func Test_GetAppConfigs(t *testing.T) { @@ -35,7 +35,7 @@ func Test_GetAppConfigs(t *testing.T) { neo4jConfigsFound = false passwordExpirationValue appcfg.PasswordExpiration neo4jParametersValue appcfg.Neo4jParameters - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) ) config, err := testCtx.AdminClient().GetAppConfigs() @@ -66,7 +66,7 @@ func Test_GetAppConfigs(t *testing.T) { func Test_GetAppConfigWithParameter(t *testing.T) { var ( passwordExpirationValue appcfg.PasswordExpiration - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) ) config, err := testCtx.AdminClient().GetAppConfig(appcfg.PasswordExpirationWindow) @@ -90,7 +90,7 @@ func Test_PutAppConfig(t *testing.T) { "duration": updatedDuration, }, } - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) ) parameter, err := testCtx.AdminClient().PutAppConfig(updatedPasswordExpirationWindowParameter) diff --git a/cmd/api/src/api/v2/audit_integration_test.go 
b/cmd/api/src/api/v2/audit_integration_test.go index 43b0f9253f..e27d0f7101 100644 --- a/cmd/api/src/api/v2/audit_integration_test.go +++ b/cmd/api/src/api/v2/audit_integration_test.go @@ -28,7 +28,7 @@ import ( ) func Test_ListAuditLogs(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) + testCtx := integration.NewFOSSContext(t) t.Run("Test Getting Latest Audit Logs", func(t *testing.T) { var ( diff --git a/cmd/api/src/api/v2/auth_integration_test.go b/cmd/api/src/api/v2/auth_integration_test.go index d7dc49a6a7..b39f3739aa 100644 --- a/cmd/api/src/api/v2/auth_integration_test.go +++ b/cmd/api/src/api/v2/auth_integration_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 //go:build serial_integration @@ -23,11 +23,11 @@ import ( "net/http" "testing" + "github.com/specterops/bloodhound/errors" "github.com/specterops/bloodhound/src/api" "github.com/specterops/bloodhound/src/api/v2/integration" "github.com/specterops/bloodhound/src/auth" "github.com/stretchr/testify/require" - "github.com/specterops/bloodhound/errors" ) const ( @@ -38,7 +38,7 @@ const ( func Test_PermissionHandling(t *testing.T) { var ( - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) newUser = testCtx.CreateUser(otherUser, otherUser, auth.RoleReadOnly) newUserToken = testCtx.CreateAuthToken(newUser.ID, "TestToken") newUserClient = testCtx.NewAPIClientWithToken(newUserToken) @@ -53,7 +53,7 @@ func Test_PermissionHandling(t *testing.T) { func Test_AuthRolesMatchInternalDefinitions(t *testing.T) { var ( - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) actualRoles = testCtx.ListRoles() ) @@ -68,7 +68,7 @@ func Test_AuthRolesMatchInternalDefinitions(t *testing.T) { func Test_UserManagement(t *testing.T) { var ( - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) newUser = testCtx.CreateUser(otherUser, otherUser, auth.RoleReadOnly) ) @@ -133,7 +133,7 @@ func Test_UserManagement(t *testing.T) { func Test_NonAdminFunctionality(t *testing.T) { var ( - testCtx = integration.NewContext(t, integration.StartBHServer) + testCtx = integration.NewFOSSContext(t) newUser = testCtx.CreateUser(otherUser, otherUser, auth.RoleReadOnly) nonAdminUser = testCtx.CreateUser(nonAdmin, nonAdmin, auth.RoleUser) nonAdminToken = testCtx.CreateAuthToken(nonAdminUser.ID, "NonAdmin Token") diff --git a/cmd/api/src/api/v2/azure_integration_test.go b/cmd/api/src/api/v2/azure_integration_test.go index 854a9c18c9..807e6da655 100644 --- a/cmd/api/src/api/v2/azure_integration_test.go +++ 
b/cmd/api/src/api/v2/azure_integration_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 //go:build integration @@ -22,24 +22,25 @@ package v2_test import ( "context" "encoding/json" + schema "github.com/specterops/bloodhound/graphschema" "testing" - v2 "github.com/specterops/bloodhound/src/api/v2" - "github.com/specterops/bloodhound/src/test/integration" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/analysis/azure" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/graphschema/common" + v2 "github.com/specterops/bloodhound/src/api/v2" + "github.com/specterops/bloodhound/src/test/integration" + "github.com/stretchr/testify/require" ) func TestGetAZEntityInformation(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.TransactionalTest(func(harness integration.HarnessDetails, tx graph.Transaction) { objectID, err := harness.AZGroupMembership.Group.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) - groupInformation, err := v2.GetAZEntityInformation(context.Background(), testContext.GraphDB, "groups", objectID, true) + groupInformation, err := v2.GetAZEntityInformation(context.Background(), testContext.Graph.Database, "groups", objectID, true) require.Nil(t, err) groupInformationJSON, 
err := json.Marshal(groupInformation) diff --git a/cmd/api/src/api/v2/cypher_search_integration_test.go b/cmd/api/src/api/v2/cypher_search_integration_test.go index f418e516a9..31367aadbd 100644 --- a/cmd/api/src/api/v2/cypher_search_integration_test.go +++ b/cmd/api/src/api/v2/cypher_search_integration_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 //go:build integration @@ -22,13 +22,13 @@ package v2_test import ( "testing" + "github.com/specterops/bloodhound/cypher/frontend" + "github.com/specterops/bloodhound/graphschema/common" + "github.com/specterops/bloodhound/lab" v2 "github.com/specterops/bloodhound/src/api/v2" "github.com/specterops/bloodhound/src/test/lab/fixtures" "github.com/specterops/bloodhound/src/test/lab/harnesses" - "github.com/specterops/bloodhound/lab" "github.com/stretchr/testify/require" - "github.com/specterops/bloodhound/cypher/frontend" - "github.com/specterops/bloodhound/graphschema/common" ) func Test_CypherSearch(t *testing.T) { @@ -70,7 +70,7 @@ func Test_CypherSearch(t *testing.T) { assert.True(ok) graphResponse, err := apiClient.CypherSearch(v2.CypherSearch{ - Query: "match (n:Computer) return n", + Query: "match (n:Computer) where n.objectid = '" + fixtures.BasicComputerSID.String() + "' return n", }) assert.NoError(err) assert.Equal(1, len(graphResponse.Nodes)) diff --git a/cmd/api/src/api/v2/file_uploads_integration_test.go 
b/cmd/api/src/api/v2/file_uploads_integration_test.go index bf2acb2959..712b8a69a2 100644 --- a/cmd/api/src/api/v2/file_uploads_integration_test.go +++ b/cmd/api/src/api/v2/file_uploads_integration_test.go @@ -34,7 +34,7 @@ import ( ) func Test_FileUpload(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) + testCtx := integration.NewFOSSContext(t) apiClient := testCtx.AdminClient() loader := testCtx.FixtureLoader @@ -133,7 +133,7 @@ func Test_FileUpload(t *testing.T) { } func Test_FileUploadWorkFlowVersion5(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) + testCtx := integration.NewFOSSContext(t) testCtx.SendFileIngest([]string{ "v5/ingest/domains.json", @@ -152,7 +152,7 @@ func Test_FileUploadWorkFlowVersion5(t *testing.T) { } func Test_FileUploadWorkFlowVersion6(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) + testCtx := integration.NewFOSSContext(t) testCtx.SendFileIngest([]string{ "v6/ingest/domains.json", @@ -168,17 +168,12 @@ func Test_FileUploadWorkFlowVersion6(t *testing.T) { //Assert that we created stuff we expected testCtx.AssertIngest(fixtures.IngestAssertions) + testCtx.AssertIngest(fixtures.IngestAssertionsv6) } func Test_FileUploadVersion6AllOptionADCS(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) - - if adcsFlag, err := testCtx.DB.GetFlagByKey("adcs"); err != nil { - t.Fatalf("unable to get adcs flag: %v", err) - } else { - adcsFlag.Enabled = true - testCtx.DB.SetFlag(adcsFlag) - } + testCtx := integration.NewFOSSContext(t) + testCtx.ToggleFeatureFlag("adcs") testCtx.SendFileIngest([]string{ "v6/all/aiacas.json", @@ -199,7 +194,7 @@ func Test_FileUploadVersion6AllOptionADCS(t *testing.T) { } func Test_CompressedFileUploadWorkFlowVersion5(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) + testCtx := integration.NewFOSSContext(t) testCtx.SendCompressedFileIngest([]string{ 
"v5/ingest/domains.json", @@ -218,7 +213,7 @@ func Test_CompressedFileUploadWorkFlowVersion5(t *testing.T) { } func Test_CompressedFileUploadWorkFlowVersion6(t *testing.T) { - testCtx := integration.NewContext(t, integration.StartBHServer) + testCtx := integration.NewFOSSContext(t) testCtx.SendCompressedFileIngest([]string{ "v6/ingest/domains.json", @@ -234,4 +229,5 @@ func Test_CompressedFileUploadWorkFlowVersion6(t *testing.T) { //Assert that we created stuff we expected testCtx.AssertIngest(fixtures.IngestAssertions) + testCtx.AssertIngest(fixtures.IngestAssertionsv6) } diff --git a/cmd/api/src/api/v2/integration/api.go b/cmd/api/src/api/v2/integration/api.go index ce470d258f..4b865ec4ed 100644 --- a/cmd/api/src/api/v2/integration/api.go +++ b/cmd/api/src/api/v2/integration/api.go @@ -18,79 +18,25 @@ package integration import ( "context" - "fmt" + "github.com/specterops/bloodhound/src/config" + "github.com/specterops/bloodhound/src/daemons" + "github.com/specterops/bloodhound/src/services" "net/http" "time" - "github.com/specterops/bloodhound/cache" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/api" - "github.com/specterops/bloodhound/src/api/registration" - "github.com/specterops/bloodhound/src/api/router" - "github.com/specterops/bloodhound/src/auth" - "github.com/specterops/bloodhound/src/config" - "github.com/specterops/bloodhound/src/daemons" - "github.com/specterops/bloodhound/src/daemons/api/bhapi" - "github.com/specterops/bloodhound/src/daemons/datapipe" - "github.com/specterops/bloodhound/src/daemons/gc" + "github.com/specterops/bloodhound/src/bootstrap" "github.com/specterops/bloodhound/src/database" - "github.com/specterops/bloodhound/src/server" - "github.com/specterops/bloodhound/src/test/integration" "github.com/specterops/bloodhound/src/test/integration/utils" ) -type APIServerContext struct { - Context context.Context - DB *database.BloodhoundDB - GraphDB 
graph.Database - Configuration config.Configuration - APICache cache.Cache - GraphQueryCache cache.Cache -} - -type APIStartFunc func(ctx APIServerContext) error - -func StartBHServer(apiServerContext APIServerContext) error { - if err := server.InitializeLogging(apiServerContext.Configuration); err != nil { - return fmt.Errorf("log initialization error: %w", err) - } - +func (s *Context) APIServerURL(paths ...string) string { var ( - serviceManager = daemons.NewManager(server.DefaultServerShutdownTimeout) - sessionSweepingService = gc.NewDataPruningDaemon(apiServerContext.DB) - routerInst = router.NewRouter(apiServerContext.Configuration, auth.NewAuthorizer(), server.ContentSecurityPolicy) - fakeManifests = config.CollectorManifests{} - datapipeDaemon = datapipe.NewDaemon(apiServerContext.Configuration, apiServerContext.DB, apiServerContext.GraphDB, apiServerContext.GraphQueryCache, time.Second) - authenticator = api.NewAuthenticator(apiServerContext.Configuration, apiServerContext.DB, database.NewContextInitializer(apiServerContext.DB)) - ) - - registration.RegisterFossGlobalMiddleware(&routerInst, apiServerContext.Configuration, auth.NewIdentityResolver(), authenticator) - registration.RegisterFossRoutes( - &routerInst, - apiServerContext.Configuration, - apiServerContext.DB, - apiServerContext.GraphDB, - apiServerContext.APICache, - apiServerContext.GraphQueryCache, - fakeManifests, - authenticator, - datapipeDaemon, + cfg = s.GetConfiguration() + fullPath, err = api.NewJoinedURL(cfg.RootURL.String(), paths...) ) - apiDaemon := bhapi.NewDaemon(apiServerContext.Configuration, routerInst.Handler()) - - // Start daemons - serviceManager.Start(apiDaemon, sessionSweepingService, datapipeDaemon) - - // Wait for a signal to exit - <-apiServerContext.Context.Done() - serviceManager.Stop() - - return nil -} - -func (s *Context) APIServerURL(paths ...string) string { - fullPath, err := api.NewJoinedURL(s.cfg.RootURL.String(), paths...) 
if err != nil { s.TestCtrl.Fatalf("Bad API server URL paths specified: %v. Paths: %v", err, paths) @@ -126,47 +72,31 @@ func (s *Context) WaitForAPI(timeout time.Duration) { } // EnableAPI loads all dependencies and starts up a new API server -func (s *Context) EnableAPI(startFunc APIStartFunc) { - log.Infof("Starting up integration test harness") - +func (s *Context) EnableAPI() { if cfg, err := utils.LoadIntegrationTestConfig(); err != nil { s.TestCtrl.Fatalf("Failed loading integration test config: %v", err) - } else if err := server.EnsureServerDirectories(cfg); err != nil { - s.TestCtrl.Fatalf("Failed ensuring integration test directories: %v", err) - } else if db, graphDB, err := server.ConnectDatabases(cfg); err != nil { - s.TestCtrl.Fatalf("Failed connecting to databases: %v", err) - } else if err := integration.Prepare(db); err != nil { - s.TestCtrl.Fatalf("Failed ensuring database: %v", err) - } else if err := server.MigrateDB(cfg, db); err != nil { - s.TestCtrl.Fatalf("Failed migrating database: %v", err) - } else if err := server.MigrateGraph(cfg, graphDB); err != nil { - s.TestCtrl.Fatalf("Failed migrating Graph database: %v", err) - } else if apiCache, err := cache.NewCache(cache.Config{MaxSize: cfg.MaxAPICacheSize}); err != nil { - s.TestCtrl.Fatalf("Failed to create in-memory cache for API: %v", err) - } else if graphQueryCache, err := cache.NewCache(cache.Config{MaxSize: cfg.MaxGraphQueryCacheSize}); err != nil { - s.TestCtrl.Fatalf("Failed to create in-memory cache for graphDB: %v", err) } else { - s.DB = db - s.Graph = graphDB - s.cfg = &cfg - // Start the HTTP API - s.WaitGroup.Add(1) + s.waitGroup.Add(1) go func() { - defer s.WaitGroup.Done() - - if err := startFunc(APIServerContext{ - Context: s.Ctx, - DB: db, - GraphDB: graphDB, - Configuration: cfg, - APICache: apiCache, - GraphQueryCache: graphQueryCache, - }); err != nil { - fmt.Printf("Error running HTTP API: %v", err) + defer s.waitGroup.Done() + + initializer := 
bootstrap.Initializer[*database.BloodhoundDB, *graph.DatabaseSwitch]{ + Configuration: cfg, + DBConnector: services.ConnectDatabases, + Entrypoint: func(ctx context.Context, cfg config.Configuration, databaseConnections bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch]) ([]daemons.Daemon, error) { + if err := databaseConnections.RDMS.Wipe(); err != nil { + return nil, err + } + + return services.Entrypoint(ctx, cfg, databaseConnections) + }, } - }() + if err := initializer.Launch(s.ctx, false); err != nil { + log.Errorf("Failed launching API server: %v", err) + } + }() } // Wait, at most, 30 seconds for the API to boot diff --git a/cmd/api/src/api/v2/integration/apiclient.go b/cmd/api/src/api/v2/integration/apiclient.go index 8a22e35f0c..af66509fd3 100644 --- a/cmd/api/src/api/v2/integration/apiclient.go +++ b/cmd/api/src/api/v2/integration/apiclient.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package integration @@ -28,7 +28,7 @@ const ( ) func (s *Context) newAPIClient() apiclient.Client { - authClient, err := apiclient.NewClient(s.cfg.RootURL.String()) + authClient, err := apiclient.NewClient(s.GetRootURL().String()) require.Nil(s.TestCtrl, err, "Unable to create auth client: %v", err) return authClient diff --git a/cmd/api/src/api/v2/integration/config.go b/cmd/api/src/api/v2/integration/config.go index a7487ed179..0c60adfa06 100644 --- a/cmd/api/src/api/v2/integration/config.go +++ b/cmd/api/src/api/v2/integration/config.go @@ -1,23 +1,24 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package integration import ( "github.com/specterops/bloodhound/src/config" + "github.com/specterops/bloodhound/src/serde" "github.com/specterops/bloodhound/src/test/integration/utils" ) @@ -39,3 +40,8 @@ func (s *Context) GetConfiguration() config.Configuration { return *s.cfg } + +func (s *Context) GetRootURL() *serde.URL { + cfg := s.GetConfiguration() + return &cfg.RootURL +} diff --git a/cmd/api/src/api/v2/integration/context.go b/cmd/api/src/api/v2/integration/context.go index cfa6d0753c..b6ef52aa09 100644 --- a/cmd/api/src/api/v2/integration/context.go +++ b/cmd/api/src/api/v2/integration/context.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. 
-// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package integration @@ -22,39 +22,35 @@ import ( "github.com/specterops/bloodhound/src/api/v2/apiclient" "github.com/specterops/bloodhound/src/config" - "github.com/specterops/bloodhound/src/database" "github.com/specterops/bloodhound/src/test" "github.com/specterops/bloodhound/src/test/fixtures" - "github.com/specterops/bloodhound/dawgs/graph" ) // Context holds integration test relevant information to be passed around to functions type Context struct { - adminClient *apiclient.Client - cfg *config.Configuration - DB database.Database - Graph graph.Database FixtureLoader fixtures.Loader TestCtrl test.Controller - Ctx context.Context - CtxDoneFunc func() - WaitGroup *sync.WaitGroup + adminClient *apiclient.Client + cfg *config.Configuration + ctx context.Context + ctxDoneFunc func() + waitGroup *sync.WaitGroup } -// NewContext creates a new integration Context -func NewContext(testCtrl test.Controller, startFunc APIStartFunc) Context { +// NewFOSSContext creates a new integration Context configured for BHCE +func NewFOSSContext(testCtrl test.Controller) Context { ctx, ctxDoneFunc := context.WithCancel(context.Background()) testCtx := Context{ TestCtrl: testCtrl, - Ctx: ctx, - CtxDoneFunc: ctxDoneFunc, + ctx: ctx, + ctxDoneFunc: ctxDoneFunc, FixtureLoader: fixtures.NewLoader(fixtures.NewTestErrorHandler(testCtrl)), - WaitGroup: &sync.WaitGroup{}, + waitGroup: &sync.WaitGroup{}, 
} // Enable the API - testCtx.EnableAPI(startFunc) + testCtx.EnableAPI() // Register teardown after starting the server since we have now mutated the environment testCtrl.Cleanup(testCtx.Teardown) @@ -62,12 +58,8 @@ func NewContext(testCtrl test.Controller, startFunc APIStartFunc) Context { return testCtx } -func (s *Context) Init() { - -} - // Teardown stops the integration test server func (s *Context) Teardown() { - s.CtxDoneFunc() - s.WaitGroup.Wait() + s.ctxDoneFunc() + s.waitGroup.Wait() } diff --git a/cmd/api/src/api/v2/integration/database.go b/cmd/api/src/api/v2/integration/database.go deleted file mode 100644 index 1c3661cbeb..0000000000 --- a/cmd/api/src/api/v2/integration/database.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2023 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// -// SPDX-License-Identifier: Apache-2.0 - -package integration - -import ( - "context" - - "github.com/specterops/bloodhound/dawgs/graph" - "github.com/specterops/bloodhound/dawgs/query" - "github.com/specterops/bloodhound/graphschema/ad" - "github.com/specterops/bloodhound/graphschema/azure" - "github.com/specterops/bloodhound/src/database" - "github.com/specterops/bloodhound/src/server" - "github.com/specterops/bloodhound/src/test/integration" -) - -func (s *Context) initDatabase() { - cfg := s.GetConfiguration() - - if db, err := server.ConnectPostgres(cfg); err != nil { - s.TestCtrl.Fatalf("Failed connecting to databases: %v", err) - } else if err := integration.Prepare(db); err != nil { - s.TestCtrl.Fatalf("Failed preparing DB: %v", err) - } else if err := server.MigrateDB(cfg, db); err != nil { - s.TestCtrl.Fatalf("Failed migrating DB: %v", err) - } else { - s.DB = db - } -} - -func (s *Context) GetDatabase() database.Database { - // If the database has not been initialized, bring it up first - if s.DB == nil { - s.initDatabase() - } - - return s.DB -} - -func (s *Context) ClearGraphDB() error { - return s.Graph.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - return tx.Nodes().Filterf(func() graph.Criteria { - return query.KindIn(query.Node(), ad.Entity, azure.Entity) - }).Delete() - }) -} diff --git a/cmd/api/src/api/v2/integration/ingest.go b/cmd/api/src/api/v2/integration/ingest.go index be0180220a..833d77e17b 100644 --- a/cmd/api/src/api/v2/integration/ingest.go +++ b/cmd/api/src/api/v2/integration/ingest.go @@ -38,6 +38,10 @@ func ingestPayload(t test.Controller, loader fixtures.Loader, fixturePath string return payload } +func (s *Context) ToggleFeatureFlag(name string) { + require.Nil(s.TestCtrl, s.AdminClient().ToggleFeatureFlag(name)) +} + func (s *Context) SendFileIngest(fixtures []string) { apiClient := s.AdminClient() @@ -136,10 +140,10 @@ func (s *Context) WaitForDatapipeAnalysis(timeout time.Duration, 
originalWrapper type IngestAssertion func(testCtrl test.Controller, tx graph.Transaction) func (s *Context) AssertIngest(assertion IngestAssertion) { - graphDB := integration.OpenNeo4jGraphDB(s.TestCtrl) - defer graphDB.Close() + graphDB := integration.OpenGraphDB(s.TestCtrl) + defer graphDB.Close(s.ctx) - require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.Ctx, func(tx graph.Transaction) error { + require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.ctx, func(tx graph.Transaction) error { assertion(s.TestCtrl, tx) return nil }), "Unexpected database error during reconciliation assertion") diff --git a/cmd/api/src/api/v2/integration/reconciliation.go b/cmd/api/src/api/v2/integration/reconciliation.go index e8b5d2ec5f..f7e767a579 100644 --- a/cmd/api/src/api/v2/integration/reconciliation.go +++ b/cmd/api/src/api/v2/integration/reconciliation.go @@ -1,35 +1,35 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package integration import ( + "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/src/test" "github.com/specterops/bloodhound/src/test/integration" "github.com/stretchr/testify/require" - "github.com/specterops/bloodhound/dawgs/graph" ) type ReconciliationAssertion func(testCtrl test.Controller, tx graph.Transaction) func (s *Context) AssertReconciliation(assertion ReconciliationAssertion) { - graphDB := integration.OpenNeo4jGraphDB(s.TestCtrl) - defer graphDB.Close() + graphDB := integration.OpenGraphDB(s.TestCtrl) + defer graphDB.Close(s.ctx) - require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.Ctx, func(tx graph.Transaction) error { + require.Nil(s.TestCtrl, graphDB.ReadTransaction(s.ctx, func(tx graph.Transaction) error { assertion(s.TestCtrl, tx) return nil }), "Unexpected database error during reconciliation assertion") diff --git a/cmd/api/src/api/v2/model.go b/cmd/api/src/api/v2/model.go index eb40000970..2d5cc66ee4 100644 --- a/cmd/api/src/api/v2/model.go +++ b/cmd/api/src/api/v2/model.go @@ -1,32 +1,32 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package v2 import ( + "github.com/gorilla/schema" + "github.com/specterops/bloodhound/cache" + _ "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/src/config" "github.com/specterops/bloodhound/src/daemons/datapipe" "github.com/specterops/bloodhound/src/database" "github.com/specterops/bloodhound/src/model" "github.com/specterops/bloodhound/src/queries" "github.com/specterops/bloodhound/src/serde" - "github.com/specterops/bloodhound/cache" - "github.com/gorilla/schema" - _ "github.com/specterops/bloodhound/dawgs/drivers/neo4j" - "github.com/specterops/bloodhound/dawgs/graph" ) type ListPermissionsResponse struct { @@ -144,16 +144,19 @@ type Resources struct { } func NewResources( - db database.Database, graphDB graph.Database, cfg config.Configuration, - apiCache cache.Cache, graphQueryCache cache.Cache, + rdms database.Database, + graphDB *graph.DatabaseSwitch, + cfg config.Configuration, + apiCache cache.Cache, + graphQuery queries.Graph, collectorManifests config.CollectorManifests, taskNotifier datapipe.Tasker, ) Resources { return Resources{ Decoder: schema.NewDecoder(), - DB: db, + DB: rdms, Graph: graphDB, // TODO: to be phased out in favor of graph queries - GraphQuery: queries.NewGraphQuery(graphDB, graphQueryCache, cfg.SlowQueryThreshold, cfg.DisableCypherQC), + GraphQuery: graphQuery, Config: cfg, QueryParameterFilterParser: model.NewQueryParameterFilterParser(), Cache: apiCache, diff --git a/cmd/api/src/auth/model.go b/cmd/api/src/auth/model.go index 0e6e62799d..5a009dcff4 100644 --- a/cmd/api/src/auth/model.go +++ b/cmd/api/src/auth/model.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package auth @@ -23,11 +23,11 @@ import ( "strconv" "time" - "github.com/specterops/bloodhound/src/database/types/null" - "github.com/specterops/bloodhound/src/model" "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v4" "github.com/specterops/bloodhound/errors" + "github.com/specterops/bloodhound/src/database/types/null" + "github.com/specterops/bloodhound/src/model" ) const ( @@ -85,7 +85,8 @@ func (s idResolver) GetIdentity(ctx Context) (SimpleIdentity, error) { type Authorizer interface { AllowsPermission(ctx Context, requiredPermission model.Permission) bool - AllowsPermissions(ctx Context, requiredPermissions model.Permissions) bool + AllowsAllPermissions(ctx Context, requiredPermissions model.Permissions) bool + AllowsAtLeastOnePermission(ctx Context, requiredPermissions model.Permissions) bool } type authorizer struct{} @@ -106,7 +107,7 @@ func (s authorizer) AllowsPermission(ctx Context, requiredPermission model.Permi return false } -func (s authorizer) AllowsPermissions(ctx Context, requiredPermissions model.Permissions) bool { +func (s authorizer) AllowsAllPermissions(ctx Context, requiredPermissions model.Permissions) bool { for _, permission := range requiredPermissions { if !s.AllowsPermission(ctx, permission) { return false @@ -116,6 +117,16 @@ func (s authorizer) AllowsPermissions(ctx Context, requiredPermissions model.Per return true } +func (s authorizer) AllowsAtLeastOnePermission(ctx Context, requiredPermissions model.Permissions) bool { + for _, permission := range 
requiredPermissions { + if s.AllowsPermission(ctx, permission) { + return true + } + } + + return false +} + type Context struct { PermissionOverrides PermissionOverrides Owner any diff --git a/cmd/api/src/auth/permission.go b/cmd/api/src/auth/permission.go index 62221a1b83..af2dcb78d5 100644 --- a/cmd/api/src/auth/permission.go +++ b/cmd/api/src/auth/permission.go @@ -44,6 +44,8 @@ type PermissionSet struct { SavedQueriesRead model.Permission SavedQueriesWrite model.Permission + + ClientsRead model.Permission } func (s PermissionSet) All() model.Permissions { @@ -64,6 +66,7 @@ func (s PermissionSet) All() model.Permissions { s.APsManageAPs, s.SavedQueriesRead, s.SavedQueriesWrite, + s.ClientsRead, } } @@ -92,5 +95,7 @@ func Permissions() PermissionSet { SavedQueriesRead: model.NewPermission("saved_queries", "Read"), SavedQueriesWrite: model.NewPermission("saved_queries", "Write"), + + ClientsRead: model.NewPermission("clients", "Read"), } } diff --git a/cmd/api/src/auth/role.go b/cmd/api/src/auth/role.go index 47522d90b5..7a56745ef1 100644 --- a/cmd/api/src/auth/role.go +++ b/cmd/api/src/auth/role.go @@ -89,13 +89,13 @@ func Roles() map[string]RoleTemplate { Description: "Can read data, modify asset group memberships", Permissions: model.Permissions{ permissions.GraphDBRead, - permissions.ClientsManage, permissions.AuthCreateToken, permissions.AuthManageSelf, permissions.APsGenerateReport, permissions.AppReadApplicationConfiguration, permissions.SavedQueriesRead, permissions.SavedQueriesWrite, + permissions.ClientsRead, }, }, RoleAdministrator: { diff --git a/cmd/api/src/bootstrap/initializer.go b/cmd/api/src/bootstrap/initializer.go new file mode 100644 index 0000000000..5faa4a361c --- /dev/null +++ b/cmd/api/src/bootstrap/initializer.go @@ -0,0 +1,83 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap + +import ( + "context" + "fmt" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/config" + "github.com/specterops/bloodhound/src/daemons" + "github.com/specterops/bloodhound/src/database" +) + +type DatabaseConnections[DBType database.Database, GraphType graph.Database] struct { + RDMS DBType + Graph GraphType +} + +type DatabaseConstructor[DBType database.Database, GraphType graph.Database] func(ctx context.Context, cfg config.Configuration) (DatabaseConnections[DBType, GraphType], error) +type InitializerLogic[DBType database.Database, GraphType graph.Database] func(ctx context.Context, cfg config.Configuration, databaseConnections DatabaseConnections[DBType, GraphType]) ([]daemons.Daemon, error) + +type Initializer[DBType database.Database, GraphType graph.Database] struct { + Configuration config.Configuration + Entrypoint InitializerLogic[DBType, GraphType] + DBConnector DatabaseConstructor[DBType, GraphType] +} + +func (s Initializer[DBType, GraphType]) Launch(parentCtx context.Context, handleSignals bool) error { + var ( + ctx = parentCtx + daemonManager = daemons.NewManager(DefaultServerShutdownTimeout) + ) + + if handleSignals { + ctx = NewDaemonContext(parentCtx) + } + + if err := InitializeLogging(s.Configuration); err != nil { + return fmt.Errorf("log initialization error: %w", err) + } + + if err := EnsureServerDirectories(s.Configuration); err != nil { + return fmt.Errorf("failed to ensure 
server directories: %w", err) + } + + if databaseConnections, err := s.DBConnector(ctx, s.Configuration); err != nil { + return fmt.Errorf("failed to connect to databases: %w", err) + } else if daemonInstances, err := s.Entrypoint(ctx, s.Configuration, databaseConnections); err != nil { + return fmt.Errorf("failed to start services: %w", err) + } else { + // Ensure that the database instances are closed once we're ready to exit regardless of p + defer databaseConnections.RDMS.Close() + defer databaseConnections.Graph.Close(context.Background()) + + daemonManager.Start(daemonInstances...) + } + + // Log successful start and wait for a signal to exit + log.Infof("Server started successfully") + <-ctx.Done() + + log.Infof("Shutting down") + + // TODO: Refactor this pattern in favor of context handling + daemonManager.Stop() + + return nil +} diff --git a/cmd/api/src/server/server.go b/cmd/api/src/bootstrap/server.go similarity index 52% rename from cmd/api/src/server/server.go rename to cmd/api/src/bootstrap/server.go index 6a37b3c0f9..f8de3fa299 100644 --- a/cmd/api/src/server/server.go +++ b/cmd/api/src/bootstrap/server.go @@ -14,9 +14,10 @@ // // SPDX-License-Identifier: Apache-2.0 -package server +package bootstrap import ( + "context" "fmt" "os" "os/signal" @@ -25,19 +26,10 @@ import ( "time" iso8601 "github.com/channelmeter/iso8601duration" - "github.com/specterops/bloodhound/cache" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/log" - "github.com/specterops/bloodhound/src/api" - "github.com/specterops/bloodhound/src/api/registration" - "github.com/specterops/bloodhound/src/api/router" "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/src/config" - "github.com/specterops/bloodhound/src/daemons" - "github.com/specterops/bloodhound/src/daemons/api/bhapi" - "github.com/specterops/bloodhound/src/daemons/api/toolapi" - "github.com/specterops/bloodhound/src/daemons/datapipe" - 
"github.com/specterops/bloodhound/src/daemons/gc" "github.com/specterops/bloodhound/src/database" "github.com/specterops/bloodhound/src/database/types/null" "github.com/specterops/bloodhound/src/migrations" @@ -50,11 +42,12 @@ const ( ContentSecurityPolicy = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self' data:;" ) -// SystemSignalExitChannel is used to shut down the server. It creates a channel that listens for an exit signal from the server. -func SystemSignalExitChannel() chan struct{} { - exitC := make(chan struct{}) +func NewDaemonContext(parentCtx context.Context) context.Context { + daemonContext, doneFunc := context.WithCancel(parentCtx) go func() { + defer doneFunc() + // Shutdown on SIGINT/SIGTERM signalChannel := make(chan os.Signal, 1) signal.Notify(signalChannel, syscall.SIGTERM) @@ -62,28 +55,18 @@ func SystemSignalExitChannel() chan struct{} { // Wait for a signal from the OS <-signalChannel - close(exitC) }() - return exitC + return daemonContext } // MigrateGraph runs migrations for the graph database -func MigrateGraph(cfg config.Configuration, db graph.Database) error { - if cfg.DisableMigrations { - log.Infof("Graph migrations are disabled per configuration") - return nil - } - return migrations.NewGraphMigrator(db).Migrate() +func MigrateGraph(ctx context.Context, db graph.Database, schema graph.Schema) error { + return migrations.NewGraphMigrator(db).Migrate(ctx, schema) } // MigrateDB runs database migrations on PG func MigrateDB(cfg config.Configuration, db database.Database) error { - if cfg.DisableMigrations { - log.Infof("Database migrations are disabled per configuration") - return nil - } - if err := db.Migrate(); err != nil { return err } @@ -145,56 +128,3 @@ func MigrateDB(cfg config.Configuration, db database.Database) error { return nil } - -// StartServer sets up background daemons, runs the service and waits for an exit signal to shut it down 
-func StartServer(cfg config.Configuration, exitC chan struct{}) error { - if err := InitializeLogging(cfg); err != nil { - return fmt.Errorf("log initialization error: %w", err) - } - - if db, graphDB, err := ConnectDatabases(cfg); err != nil { - return fmt.Errorf("db connection error: %w", err) - } else if err := MigrateDB(cfg, db); err != nil { - return fmt.Errorf("db migration error: %w", err) - } else if err := MigrateGraph(cfg, graphDB); err != nil { - return fmt.Errorf("graph db migration error: %w", err) - } else if apiCache, err := cache.NewCache(cache.Config{MaxSize: cfg.MaxAPICacheSize}); err != nil { - return fmt.Errorf("failed to create in-memory cache for API: %w", err) - } else if graphQueryCache, err := cache.NewCache(cache.Config{MaxSize: cfg.MaxAPICacheSize}); err != nil { - return fmt.Errorf("failed to create in-memory cache for graph queries: %w", err) - } else if collectorManifests, err := cfg.SaveCollectorManifests(); err != nil { - return fmt.Errorf("failed to save collector manifests: %w", err) - } else { - var ( - serviceManager = daemons.NewManager(DefaultServerShutdownTimeout) - sessionSweepingService = gc.NewDataPruningDaemon(db) - routerInst = router.NewRouter(cfg, auth.NewAuthorizer(), ContentSecurityPolicy) - toolingService = toolapi.NewDaemon(cfg, db) - datapipeDaemon = datapipe.NewDaemon(cfg, db, graphDB, graphQueryCache, time.Duration(cfg.DatapipeInterval)*time.Second) - authenticator = api.NewAuthenticator(cfg, db, database.NewContextInitializer(db)) - ) - - registration.RegisterFossGlobalMiddleware(&routerInst, cfg, auth.NewIdentityResolver(), authenticator) - registration.RegisterFossRoutes(&routerInst, cfg, db, graphDB, apiCache, graphQueryCache, collectorManifests, authenticator, datapipeDaemon) - apiDaemon := bhapi.NewDaemon(cfg, routerInst.Handler()) - - // Set neo4j batch and flush sizes - neo4jParameters := appcfg.GetNeo4jParameters(db) - graphDB.SetBatchWriteSize(neo4jParameters.BatchWriteSize) - 
graphDB.SetWriteFlushSize(neo4jParameters.WriteFlushSize) - - // Start daemons - serviceManager.Start(apiDaemon, toolingService, sessionSweepingService, datapipeDaemon) - - log.Infof("Server started successfully") - // Wait for a signal to exit - <-exitC - - log.Infof("Shutting down") - serviceManager.Stop() - - log.Infof("Server shut down successfully") - } - - return nil -} diff --git a/cmd/api/src/server/util.go b/cmd/api/src/bootstrap/util.go similarity index 61% rename from cmd/api/src/server/util.go rename to cmd/api/src/bootstrap/util.go index 6e12a87c29..14fd4e6709 100644 --- a/cmd/api/src/server/util.go +++ b/cmd/api/src/bootstrap/util.go @@ -14,21 +14,21 @@ // // SPDX-License-Identifier: Apache-2.0 -package server +package bootstrap import ( + "context" "fmt" - "os" - "path/filepath" - "github.com/specterops/bloodhound/dawgs" + "github.com/specterops/bloodhound/dawgs/drivers/neo4j" _ "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/specterops/bloodhound/dawgs/drivers/pg" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/util/size" "github.com/specterops/bloodhound/log" - "github.com/specterops/bloodhound/src/auth" + "github.com/specterops/bloodhound/src/api/tools" "github.com/specterops/bloodhound/src/config" - "github.com/specterops/bloodhound/src/database" + "os" ) func ensureDirectory(path string) error { @@ -67,49 +67,40 @@ func EnsureServerDirectories(cfg config.Configuration) error { return nil } -func mustGetWorkingDirectory() string { - workingDirectory, err := os.Getwd() - - if err != nil { - fmt.Printf("Unable to lookup working directory: %v", err) - os.Exit(1) - } - - return workingDirectory -} - // DefaultConfigFilePath returns the location of the config file func DefaultConfigFilePath() string { return "/etc/bhapi/bhapi.json" } -// DefaultWorkDirPath returns the default location of the working directory -func DefaultWorkDirPath() string { - return 
filepath.Join(mustGetWorkingDirectory(), "work") -} +func ConnectGraph(ctx context.Context, cfg config.Configuration) (*graph.DatabaseSwitch, error) { + var connectionString string -// ConnectPostgres initializes a connection to PG, and returns errors if any -func ConnectPostgres(cfg config.Configuration) (*database.BloodhoundDB, error) { - if db, err := database.OpenDatabase(cfg.Database.PostgreSQLConnectionString()); err != nil { - return nil, fmt.Errorf("error while attempting to create database connection: %w", err) + if driverName, err := tools.LookupGraphDriver(ctx, cfg); err != nil { + return nil, err } else { - return database.NewBloodhoundDB(db, auth.NewIdentityResolver()), nil - } -} + switch driverName { + case neo4j.DriverName: + log.Infof("Connecting to graph using Neo4j") + connectionString = cfg.Neo4J.Neo4jConnectionString() -// ConnectDatabases initializes connections to PG and connection, and returns errors if any -func ConnectDatabases(cfg config.Configuration) (*database.BloodhoundDB, graph.Database, error) { - dawgsCfg := dawgs.Config{ - DriverCfg: cfg.Neo4J.Neo4jConnectionString(), - TraversalMemoryLimit: size.Size(cfg.TraversalMemoryLimit) * size.Gibibyte, - } + case pg.DriverName: + log.Infof("Connecting to graph using PostgreSQL") + connectionString = cfg.Database.PostgreSQLConnectionString() - if db, err := ConnectPostgres(cfg); err != nil { - return nil, nil, err - } else if graphDatabase, err := dawgs.Open("neo4j", dawgsCfg); err != nil { - return nil, nil, err - } else { - return db, graphDatabase, nil + default: + return nil, fmt.Errorf("unknown graphdb driver name: %s", driverName) + } + + if connectionString == "" { + return nil, fmt.Errorf("graph connection requires a connection url to be set") + } else if graphDatabase, err := dawgs.Open(ctx, driverName, dawgs.Config{ + TraversalMemoryLimit: size.Size(cfg.TraversalMemoryLimit) * size.Gibibyte, + DriverCfg: connectionString, + }); err != nil { + return nil, err + } else { + return 
graph.NewDatabaseSwitch(ctx, graphDatabase), nil + } } } diff --git a/cmd/api/src/cmd/bhapi/main.go b/cmd/api/src/cmd/bhapi/main.go index e283ff020c..3de6b647d3 100644 --- a/cmd/api/src/cmd/bhapi/main.go +++ b/cmd/api/src/cmd/bhapi/main.go @@ -17,15 +17,17 @@ package main import ( + "context" "flag" "fmt" - "os" - + "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/bootstrap" "github.com/specterops/bloodhound/src/config" - "github.com/specterops/bloodhound/src/migrations" - "github.com/specterops/bloodhound/src/server" + "github.com/specterops/bloodhound/src/database" + "github.com/specterops/bloodhound/src/services" "github.com/specterops/bloodhound/src/version" + "os" // This import is required by swaggo _ "github.com/specterops/bloodhound/src/docs" @@ -36,26 +38,10 @@ func printVersion() { os.Exit(0) } -func performMigrationsOnly(cfg config.Configuration) { - if db, graphDB, err := server.ConnectDatabases(cfg); err != nil { - log.Fatalf("Failed connecting to databases: %v", err) - } else if err := db.Migrate(); err != nil { - log.Fatalf("Migrations failed: %v", err) - } else { - var migrator = migrations.NewGraphMigrator(graphDB) - if err := migrator.Migrate(); err != nil { - log.Fatalf("Error running migrations for graph db: %v", err) - } - } - - fmt.Println("Migrations executed successfully") -} - func main() { var ( configFilePath string logFilePath string - migrationFlag bool versionFlag bool ) @@ -64,9 +50,8 @@ func main() { flag.PrintDefaults() } - flag.BoolVar(&migrationFlag, "migrate", false, "Only perform database migrations. 
Do not start the server.") flag.BoolVar(&versionFlag, "version", false, "Get binary version.") - flag.StringVar(&configFilePath, "configfile", server.DefaultConfigFilePath(), "Configuration file to load.") + flag.StringVar(&configFilePath, "configfile", bootstrap.DefaultConfigFilePath(), "Configuration file to load.") flag.StringVar(&logFilePath, "logfile", config.DefaultLogFilePath, "Log file to write to.") flag.Parse() @@ -77,13 +62,17 @@ func main() { // Initialize basic logging facilities while we start up log.ConfigureDefaults() - if cfg, err := config.GetConfiguration(configFilePath); err != nil { + if cfg, err := config.GetConfiguration(configFilePath, config.NewDefaultConfiguration); err != nil { log.Fatalf("Unable to read configuration %s: %v", configFilePath, err) - } else if err := server.EnsureServerDirectories(cfg); err != nil { - log.Fatalf("Fatal error while attempting to ensure working directories: %v", err) - } else if migrationFlag { - performMigrationsOnly(cfg) - } else if err := server.StartServer(cfg, server.SystemSignalExitChannel()); err != nil { - log.Fatalf("Server start error: %v", err) + } else { + initializer := bootstrap.Initializer[*database.BloodhoundDB, *graph.DatabaseSwitch]{ + Configuration: cfg, + DBConnector: services.ConnectDatabases, + Entrypoint: services.Entrypoint, + } + + if err := initializer.Launch(context.Background(), true); err != nil { + log.Fatalf("Failed starting the server: %v", err) + } } } diff --git a/cmd/api/src/cmd/dawgs-harness/main.go b/cmd/api/src/cmd/dawgs-harness/main.go index 1f75ea82b5..923b302357 100644 --- a/cmd/api/src/cmd/dawgs-harness/main.go +++ b/cmd/api/src/cmd/dawgs-harness/main.go @@ -20,17 +20,20 @@ import ( "context" "flag" "fmt" - "net/http" + "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/specterops/bloodhound/dawgs/drivers/pg" + "github.com/specterops/bloodhound/dawgs/util/size" + schema "github.com/specterops/bloodhound/graphschema" _ "net/http/pprof" "os" + 
"os/signal" + "runtime/pprof" + "syscall" "time" "github.com/jedib0t/go-pretty/v6/table" "github.com/specterops/bloodhound/dawgs" - "github.com/specterops/bloodhound/dawgs/drivers/neo4j" "github.com/specterops/bloodhound/dawgs/graph" - "github.com/specterops/bloodhound/graphschema/ad" - "github.com/specterops/bloodhound/graphschema/common" "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/cmd/dawgs-harness/tests" ) @@ -40,66 +43,117 @@ func fatalf(format string, args ...any) { os.Exit(1) } -func RunNeo4jTestSuite(dbHost string) tests.TestSuite { - if connection, err := dawgs.Open("neo4j", dawgs.Config{DriverCfg: fmt.Sprintf("neo4j://neo4j:neo4jj@%s:7687/neo4j", dbHost)}); err != nil { - fatalf("Failed opening neo4j: %v", err) - } else if err := connection.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - return tx.Nodes().Delete() +func RunTestSuite(ctx context.Context, connectionStr, driverName string) tests.TestSuite { + if connection, err := dawgs.Open(context.TODO(), driverName, dawgs.Config{ + TraversalMemoryLimit: size.Gibibyte, + DriverCfg: connectionStr, }); err != nil { - fatalf("Failed to clear neo4j: %v", err) - } else if err := neo4j.AssertNodePropertyIndex(connection, ad.Entity, common.Name.String(), graph.BTreeIndex); err != nil { - fatalf("Error creating database schema: %v", err) - } else if testSuite, err := tests.RunSuite(tests.Neo4j, connection); err != nil { - fatalf("Test suite error: %v", err) + fatalf("Failed opening %s database: %v", driverName, err) } else { - connection.Close() - return testSuite + defer connection.Close(ctx) + + if err := connection.AssertSchema(ctx, schema.DefaultGraphSchema()); err != nil { + fatalf("Failed asserting graph schema on %s database: %v", driverName, err) + } else if err := connection.WriteTransaction(ctx, func(tx graph.Transaction) error { + return tx.Nodes().Delete() + }); err != nil { + fatalf("Failed to clear %s database: %v", driverName, err) + } else 
if testSuite, err := tests.RunSuite(connection, driverName); err != nil { + fatalf("Test suite error for %s database: %v", driverName, err) + } else { + return testSuite + } } - panic("") + panic(nil) } -func main() { +func newContext() context.Context { var ( - dbHost string - testType string - enablePprof bool + ctx, done = context.WithCancel(context.Background()) + sigchnl = make(chan os.Signal) ) - flag.StringVar(&testType, "test", "both", "Test to run. Must be one of: 'postgres', 'neo4j', 'both'") - flag.BoolVar(&enablePprof, "enable-pprof", false, "Enable the pprof HTTP sampling server.") - flag.IntVar(&tests.SimpleRelationshipsToCreate, "num-rels", 5000, "Number of simple relationships to create.") - flag.StringVar(&dbHost, "db-host", "192.168.122.170", "Database host.") - flag.Parse() + signal.Notify(sigchnl) + + go func() { + defer done() + + for nextSignal := range sigchnl { + switch nextSignal { + case syscall.SIGINT, syscall.SIGTERM: + return + } + } + }() + return ctx +} + +var enablePprof bool + +func execSuite(name string, logic func() tests.TestSuite) tests.TestSuite { if enablePprof { - go func() { - if err := http.ListenAndServe("localhost:8080", nil); err != nil { - log.Error().Fault(err).Msg("HTTP server caught an error while running.") + if cpuProfileFile, err := os.OpenFile(name+".pprof", syscall.O_WRONLY|syscall.O_TRUNC|syscall.O_CREAT, 0644); err != nil { + fatalf("Unable to open file for CPU profile: %v", err) + } else { + defer cpuProfileFile.Close() + + if err := pprof.StartCPUProfile(cpuProfileFile); err != nil { + fatalf("Failed to start CPU profile: %v", err) + } else { + defer pprof.StopCPUProfile() } - }() + } } + return logic() +} + +func main() { + var ( + ctx = newContext() + neo4jConnectionStr string + pgConnectionStr string + testType string + ) + + flag.StringVar(&testType, "test", "both", "Test to run. 
Must be one of: 'postgres', 'neo4j', 'both'") + flag.BoolVar(&enablePprof, "enable-pprof", true, "Enable the pprof HTTP sampling server.") + flag.IntVar(&tests.SimpleRelationshipsToCreate, "num-rels", 2000, "Number of simple relationships to create.") + flag.StringVar(&neo4jConnectionStr, "neo4j", "neo4j://neo4j:neo4jj@localhost:7687", "Neo4j connection string.") + flag.StringVar(&pgConnectionStr, "pg", "user=bhe dbname=bhe password=bhe4eva host=localhost", "PostgreSQL connection string.") + flag.Parse() + log.ConfigureDefaults() switch testType { - //case "both": - // pgTestSuite := RunPostgresqlTestSuite(dbHost) - // - // // Sleep between tests - // time.Sleep(time.Second * 3) - // fmt.Println() - // - // n4jTestSuite := RunNeo4jTestSuite(dbHost) - // fmt.Println() - // - // OutputTestSuiteDeltas(pgTestSuite, n4jTestSuite) - // - //case "postgres": - // RunPostgresqlTestSuite(dbHost) + case "both": + n4jTestSuite := execSuite(neo4j.DriverName, func() tests.TestSuite { + return RunTestSuite(ctx, neo4jConnectionStr, neo4j.DriverName) + }) + + fmt.Println() + + // Sleep between tests + time.Sleep(time.Second * 3) + + pgTestSuite := execSuite(pg.DriverName, func() tests.TestSuite { + return RunTestSuite(ctx, pgConnectionStr, pg.DriverName) + }) + fmt.Println() + + OutputTestSuiteDeltas(pgTestSuite, n4jTestSuite) + + case "postgres": + execSuite(pg.DriverName, func() tests.TestSuite { + return RunTestSuite(ctx, pgConnectionStr, pg.DriverName) + }) case "neo4j": - RunNeo4jTestSuite(dbHost) + execSuite(neo4j.DriverName, func() tests.TestSuite { + return RunTestSuite(ctx, neo4jConnectionStr, neo4j.DriverName) + }) } } diff --git a/cmd/api/src/cmd/dawgs-harness/tests/case.go b/cmd/api/src/cmd/dawgs-harness/tests/case.go index a304344877..1ada23d41c 100644 --- a/cmd/api/src/cmd/dawgs-harness/tests/case.go +++ b/cmd/api/src/cmd/dawgs-harness/tests/case.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. 
-// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package tests @@ -26,24 +26,6 @@ import ( "github.com/specterops/bloodhound/dawgs/graph" ) -type DBType int - -func (s DBType) String() string { - switch s { - case Neo4j: - return "neo4j" - case Postgres: - return "postgresql" - default: - panic(fmt.Sprintf("unknown DB type: %d", s)) - } -} - -const ( - Neo4j DBType = iota - Postgres -) - type TestDelegate func(testCase *TestCase) any type Sample struct { @@ -92,9 +74,8 @@ func (s Samples) LongestDuration() time.Duration { } type TestSuite struct { - Name string - DBType DBType - Cases []*TestCase + Name string + Cases []*TestCase } func (s *TestSuite) GetTestCase(testName string) *TestCase { @@ -110,7 +91,6 @@ func (s *TestSuite) GetTestCase(testName string) *TestCase { func (s *TestSuite) NewTestCase(testName string, delegate TestDelegate) { s.Cases = append(s.Cases, &TestCase{ Name: testName, - DBType: s.DBType, Delegate: delegate, }) } @@ -160,7 +140,6 @@ func (s *TestSuite) Execute(db graph.Database) error { type TestCase struct { Name string - DBType DBType Delegate TestDelegate Samples Samples Duration time.Duration diff --git a/cmd/api/src/cmd/dawgs-harness/tests/suite.go b/cmd/api/src/cmd/dawgs-harness/tests/suite.go index 5457849e64..450e246d15 100644 --- a/cmd/api/src/cmd/dawgs-harness/tests/suite.go +++ b/cmd/api/src/cmd/dawgs-harness/tests/suite.go @@ -1,27 +1,34 @@ // Copyright 2023 Specter Ops, Inc. 
-// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package tests import ( + "context" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/graphschema/common" ) -func RunSuite(dbType DBType, db graph.Database) (TestSuite, error) { +func RunSuite(db graph.Database, driverName string) (TestSuite, error) { + if err := db.WriteTransaction(context.Background(), func(tx graph.Transaction) error { + return tx.Nodes().Delete() + }); err != nil { + return TestSuite{}, err + } + // Clear IDs StartNodeIDs = make([]graph.ID, SimpleRelationshipsToCreate) EndNodeIDs = make([]graph.ID, SimpleRelationshipsToCreate) @@ -29,27 +36,26 @@ func RunSuite(dbType DBType, db graph.Database) (TestSuite, error) { // Setup and run the test suite suite := TestSuite{ - Name: dbType.String(), - DBType: dbType, + Name: driverName, } suite.NewTestCase("Node and Relationship Creation", NodeAndRelationshipCreationTest) suite.NewTestCase("Batch Node and Relationship Creation", BatchNodeAndRelationshipCreationTest) suite.NewTestCase("Fetch Nodes by ID", FetchNodesByID) - //suite.NewTestCase("Fetch Nodes by Filter Property", FetchNodesByProperty(ad.ObjectID)) - suite.NewTestCase("Fetch Nodes by Indexed Property", FetchNodesByProperty(common.Name.String())) + suite.NewTestCase("Fetch Nodes by Filter Item", FetchNodesByProperty(common.ObjectID.String(), SimpleRelationshipsToCreate/4)) + suite.NewTestCase("Fetch Nodes by Indexed Item", 
FetchNodesByProperty(common.Name.String(), SimpleRelationshipsToCreate/4)) suite.NewTestCase("Fetch Nodes by Slice of Filter Properties", FetchNodesByPropertySlice(common.ObjectID.String())) suite.NewTestCase("Fetch Nodes by Slice of Indexed Properties", FetchNodesByPropertySlice(common.Name.String())) suite.NewTestCase("Node Update", NodeUpdateTests) suite.NewTestCase("Fetch Relationships by ID", FetchRelationshipsByID) - //suite.NewTestCase("Fetch Relationships by Filter Property", FetchRelationshipsByProperty(common.Name.String())) + suite.NewTestCase("Fetch Relationships by Filter Item", FetchRelationshipsByProperty(common.Name.String())) suite.NewTestCase("Fetch Relationships by Slice of Filter Properties", FetchRelationshipsByPropertySlice) - suite.NewTestCase("Fetch Relationships by Indexed Start Node Property", FetchRelationshipByStartNodeProperty) + suite.NewTestCase("Fetch Relationships by Indexed Start Node Item", FetchRelationshipByStartNodeProperty) - suite.NewTestCase("Fetch Directional Result by Indexed Start Node Property", FetchDirectionalResultByStartNodeProperty) + suite.NewTestCase("Fetch Directional Result by Indexed Start Node Item", FetchDirectionalResultByStartNodeProperty) suite.NewTestCase("Batch Delete Nodes by ID", BatchDeleteEndNodesByID) suite.NewTestCase("Delete Nodes by Slice of IDs", DeleteStartNodesByIDSlice) diff --git a/cmd/api/src/cmd/dawgs-harness/tests/tests.go b/cmd/api/src/cmd/dawgs-harness/tests/tests.go index 918cda82e2..06fba78317 100644 --- a/cmd/api/src/cmd/dawgs-harness/tests/tests.go +++ b/cmd/api/src/cmd/dawgs-harness/tests/tests.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package tests @@ -61,15 +61,12 @@ func FetchNodesByID(testCase *TestCase) any { } } -func FetchNodesByProperty(propertyName string) func(testCase *TestCase) any { +func FetchNodesByProperty(propertyName string, maxFetches int) func(testCase *TestCase) any { return func(testCase *TestCase) any { return func(tx graph.Transaction) error { - var ( - numExpectedFetches = len(StartNodeIDs) - resultsFetched = 0 - ) + resultsFetched := 0 - for iteration := 0; iteration < numExpectedFetches; iteration++ { + for iteration := 0; iteration < maxFetches; iteration++ { propertyValue := "batch start node " + strconv.Itoa(iteration) if iteration == 0 { @@ -80,10 +77,6 @@ func FetchNodesByProperty(propertyName string) func(testCase *TestCase) any { if err := testCase.Sample(func() error { return tx.Nodes().Filterf(func() graph.Criteria { - if testCase.DBType == Postgres { - return query.Equals(query.NodeProperty(propertyName), propertyValue) - } - return query.And( query.Kind(query.Node(), ad.Entity), query.Equals(query.NodeProperty(propertyName), propertyValue), @@ -106,7 +99,7 @@ func FetchNodesByProperty(propertyName string) func(testCase *TestCase) any { } } - return validateFetches(numExpectedFetches, resultsFetched) + return validateFetches(maxFetches, resultsFetched) } } } @@ -126,10 +119,6 @@ func FetchNodesByPropertySlice(propertyName string) func(testCase *TestCase) any if err := testCase.Sample(func() error { return tx.Nodes().Filterf(func() graph.Criteria { - if testCase.DBType == Postgres { - return 
query.In(query.NodeProperty(propertyName), propertyValues) - } - return query.And( query.Kind(query.Node(), ad.Entity), query.In(query.NodeProperty(propertyName), propertyValues), @@ -170,10 +159,6 @@ func FetchRelationshipByStartNodeProperty(testCase *TestCase) any { if err := testCase.Sample(func() error { return tx.Relationships().Filterf(func() graph.Criteria { - if testCase.DBType == Postgres { - return query.Equals(query.StartProperty(common.Name.String()), nodeName) - } - return query.And( query.Kind(query.Start(), ad.Entity), query.Equals(query.StartProperty(common.Name.String()), nodeName), @@ -222,10 +207,6 @@ func FetchDirectionalResultByStartNodeProperty(testCase *TestCase) any { if err := testCase.Sample(func() error { return tx.Relationships().Filterf(func() graph.Criteria { - if testCase.DBType == Postgres { - return query.Equals(query.StartProperty(common.Name.String()), nodeName) - } - return query.And( query.Kind(query.Start(), ad.Entity), query.Equals(query.StartProperty(common.Name.String()), nodeName), @@ -272,7 +253,7 @@ func FetchRelationshipsByPropertySlice(testCase *TestCase) any { if err := testCase.Sample(func() error { return tx.Relationships().Filterf(func() graph.Criteria { - return query.In(query.NodeProperty(common.Name.String()), relationshipNames) + return query.In(query.RelationshipProperty(common.Name.String()), relationshipNames) }).Fetch(func(cursor graph.Cursor[*graph.Relationship]) error { for relationship := range cursor.Chan() { if _, err := relationship.Properties.Get(common.Name.String()).String(); err != nil { @@ -305,7 +286,7 @@ func FetchRelationshipsByProperty(propertyName string) func(testCase *TestCase) if err := testCase.Sample(func() error { return tx.Relationships().Filterf(func() graph.Criteria { - return query.Equals(query.NodeProperty(propertyName), relationshipName) + return query.Equals(query.RelationshipProperty(propertyName), relationshipName) }).Fetch(func(cursor graph.Cursor[*graph.Relationship]) error { 
for relationship := range cursor.Chan() { if actualRelationshipName, err := relationship.Properties.Get(common.Name.String()).String(); err != nil { @@ -413,10 +394,18 @@ func BatchNodeAndRelationshipCreationTest(testCase *TestCase) any { ) if err := testCase.Sample(func() error { - return batch.CreateRelationship(startNode, endNode, ad.MemberOf, graph.AsProperties(graph.PropertyMap{ - common.Name: relationshipPropertyValue, - common.ObjectID: relationshipPropertyValue, - })) + return batch.UpdateRelationshipBy(graph.RelationshipUpdate{ + Relationship: graph.PrepareRelationship(graph.AsProperties(graph.PropertyMap{ + common.Name: relationshipPropertyValue, + common.ObjectID: relationshipPropertyValue, + }), ad.MemberOf), + Start: startNode, + StartIdentityKind: ad.Entity, + StartIdentityProperties: []string{common.ObjectID.String()}, + End: endNode, + EndIdentityKind: ad.Entity, + EndIdentityProperties: []string{common.ObjectID.String()}, + }) }); err != nil { return err } @@ -447,7 +436,7 @@ func NodeAndRelationshipCreationTest(testCase *TestCase) any { common.ObjectID: endNodePropertyValue, }), ad.Entity, ad.Group); err != nil { return err - } else if relationship, err := tx.CreateRelationship(startNode, endNode, ad.MemberOf, graph.AsProperties(graph.PropertyMap{ + } else if relationship, err := tx.CreateRelationshipByIDs(startNode.ID, endNode.ID, ad.MemberOf, graph.AsProperties(graph.PropertyMap{ common.Name: relationshipPropertyValue, common.ObjectID: relationshipPropertyValue, })); err != nil { diff --git a/cmd/api/src/config/config.go b/cmd/api/src/config/config.go index 8c29fb84ef..f956e4a299 100644 --- a/cmd/api/src/config/config.go +++ b/cmd/api/src/config/config.go @@ -151,6 +151,7 @@ type Configuration struct { LogLevel string `json:"log_level"` LogPath string `json:"log_path"` TLS TLSConfiguration `json:"tls"` + GraphDriver string `json:"graph_driver"` Database DatabaseConfiguration `json:"database"` Neo4J DatabaseConfiguration `json:"neo4j"` Crypto 
CryptoConfiguration `json:"crypto"` @@ -255,44 +256,51 @@ func SetValuesFromEnv(varPrefix string, target any, env []string) error { return nil } -func GetConfiguration(path string) (Configuration, error) { - cfg, err := NewDefaultConfiguration() - if err != nil { - return cfg, fmt.Errorf("failed to create default configuration: %w", err) - } - +func getConfiguration(path string, defaultConfigFunc func() (Configuration, error)) (Configuration, error) { if hasCfgFile, err := HasConfigurationFile(path); err != nil { return Configuration{}, err } else if hasCfgFile { log.Infof("Reading configuration found at %s", path) - if readCfg, err := ReadConfigurationFile(path); err != nil { - return Configuration{}, err - } else { - cfg = readCfg - } + return ReadConfigurationFile(path) } else { - log.Infof("No configuration file found at %s", path) - } + log.Infof("No configuration file found at %s. Returning defaults.", path) - if err := SetValuesFromEnv(bhAPIEnvironmentVariablePrefix, &cfg, os.Environ()); err != nil { - return Configuration{}, err + return defaultConfigFunc() } +} - return cfg, nil +func GetConfiguration(path string, defaultConfigFunc func() (Configuration, error)) (Configuration, error) { + if cfg, err := getConfiguration(path, defaultConfigFunc); err != nil { + return cfg, err + } else if err := SetValuesFromEnv(bhAPIEnvironmentVariablePrefix, &cfg, os.Environ()); err != nil { + return cfg, err + } else { + return cfg, nil + } } +const ( + azureHoundCollector = "azurehound" + sharpHoundCollector = "sharphound" +) + func (s Configuration) SaveCollectorManifests() (CollectorManifests, error) { - if azureHoundManifest, err := generateCollectorManifest(filepath.Join(s.CollectorsDirectory(), "azurehound")); err != nil { - return CollectorManifests{}, fmt.Errorf("error generating AzureHound manifest file: %w", err) - } else if sharpHoundManifest, err := generateCollectorManifest(filepath.Join(s.CollectorsDirectory(), "sharphound")); err != nil { - return 
CollectorManifests{}, fmt.Errorf("error generating SharpHound manifest file: %w", err) + manifests := CollectorManifests{} + + if azureHoundManifest, err := generateCollectorManifest(filepath.Join(s.CollectorsDirectory(), azureHoundCollector)); err != nil { + log.Errorf("error generating AzureHound manifest file: %s", err) } else { - return CollectorManifests{ - "azurehound": azureHoundManifest, - "sharphound": sharpHoundManifest, - }, nil + manifests[azureHoundCollector] = azureHoundManifest } + + if sharpHoundManifest, err := generateCollectorManifest(filepath.Join(s.CollectorsDirectory(), sharpHoundCollector)); err != nil { + log.Errorf("error generating SharpHound manifest file: %s", err) + } else { + manifests[sharpHoundCollector] = sharpHoundManifest + } + + return manifests, nil } func generateCollectorManifest(collectorDir string) (CollectorManifest, error) { @@ -317,7 +325,7 @@ func generateCollectorManifest(collectorDir string) (CollectorManifest, error) { collectorVersions = append(collectorVersions, CollectorVersion{ Version: string(version), SHA256Sum: strings.Fields(string(sha256))[0], // Get only the SHA-256 portion - Deprecated: strings.Contains(collectorDir, "sharphound") && string(version) < "v2.0.0", + Deprecated: strings.Contains(collectorDir, sharpHoundCollector) && string(version) < "v2.0.0", }) if string(version) > latestVersion { diff --git a/cmd/api/src/config/default.go b/cmd/api/src/config/default.go index 560626fca7..b3d726d367 100644 --- a/cmd/api/src/config/default.go +++ b/cmd/api/src/config/default.go @@ -18,6 +18,7 @@ package config import ( "fmt" + "github.com/specterops/bloodhound/dawgs/drivers/neo4j" "github.com/specterops/bloodhound/src/serde" ) @@ -52,6 +53,7 @@ func NewDefaultConfiguration() (Configuration, error) { TraversalMemoryLimit: 2, // 2 GiB by default TLS: TLSConfiguration{}, SAML: SAMLConfiguration{}, + GraphDriver: neo4j.DriverName, // Default to Neo4j as the graph driver Database: DatabaseConfiguration{ 
MaxConcurrentSessions: 10, }, diff --git a/cmd/api/src/daemons/api/toolapi/api.go b/cmd/api/src/daemons/api/toolapi/api.go index 61acf3cdaf..4b57a40d91 100644 --- a/cmd/api/src/daemons/api/toolapi/api.go +++ b/cmd/api/src/daemons/api/toolapi/api.go @@ -1,32 +1,34 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package toolapi import ( "context" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/src/bootstrap" + "github.com/specterops/bloodhound/src/database" "net/http" "time" - "github.com/specterops/bloodhound/src/api/tools" - "github.com/specterops/bloodhound/src/config" - "github.com/specterops/bloodhound/src/database" "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/api/tools" + "github.com/specterops/bloodhound/src/config" ) // Daemon holds data relevant to the tools API daemon @@ -35,16 +37,22 @@ type Daemon struct { server *http.Server } -func NewDaemon(cfg config.Configuration, db database.Database) Daemon { +func NewDaemon[DBType database.Database](ctx context.Context, connections bootstrap.DatabaseConnections[DBType, *graph.DatabaseSwitch], cfg config.Configuration, graphSchema graph.Schema, extensions ...func(router *chi.Mux)) Daemon { var ( networkTimeout = time.Duration(cfg.NetTimeoutSeconds) * time.Second + pgMigrator = 
tools.NewPGMigrator(ctx, cfg, graphSchema, connections.Graph) router = chi.NewRouter() - toolContainer = tools.NewToolContainer(db) + toolContainer = tools.NewToolContainer(connections.RDMS) ) router.Mount("/metrics", promhttp.Handler()) router.Get("/trace", tools.NewTraceHandler()) + router.Put("/graph-db/switch/pg", pgMigrator.SwitchPostgreSQL) + router.Put("/graph-db/switch/neo4j", pgMigrator.SwitchNeo4j) + router.Put("/pg-migration/start", pgMigrator.MigrationStart) + router.Get("/pg-migration/status", pgMigrator.MigrationStatus) + router.Put("/pg-migration/cancel", pgMigrator.MigrationCancel) router.Get("/logging", tools.GetLoggingDetails) router.Put("/logging", tools.PutLoggingDetails) @@ -52,6 +60,10 @@ func NewDaemon(cfg config.Configuration, db database.Database) Daemon { router.Get("/features", toolContainer.GetFlags) router.Put("/features/{feature_id:[0-9]+}/toggle", toolContainer.ToggleFlag) + for _, extension := range extensions { + extension(router) + } + return Daemon{ cfg: cfg, server: &http.Server{ diff --git a/cmd/api/src/daemons/datapipe/agi.go b/cmd/api/src/daemons/datapipe/agi.go index 387b801a38..78a9c8c65a 100644 --- a/cmd/api/src/daemons/datapipe/agi.go +++ b/cmd/api/src/daemons/datapipe/agi.go @@ -154,90 +154,26 @@ func ParallelTagAzureTierZero(ctx context.Context, db graph.Database) error { return nil } -func ParallelTagActiveDirectoryTierZero(ctx context.Context, db graph.Database) error { +func TagActiveDirectoryTierZero(ctx context.Context, db graph.Database) error { + defer log.Measure(log.LevelInfo, "Finished tagging Active Directory Tier Zero")() + if domains, err := adAnalysis.FetchAllDomains(ctx, db); err != nil { return err } else { - var ( - domainC = make(chan *graph.Node) - rootsC = make(chan graph.ID) - writerWG = &sync.WaitGroup{} - readerWG = &sync.WaitGroup{} - ) - - readerWG.Add(1) - - go func() { - defer readerWG.Done() - - var ( - tierZeroProperties = graph.NewProperties() - rootIDs []graph.ID - ) - - 
tierZeroProperties.Set(common.SystemTags.String(), ad.AdminTierZero) - - for rootID := range rootsC { - seen := false - - for _, seenRootID := range rootIDs { - if seenRootID == rootID { - seen = true - break - } - } - - if !seen { - rootIDs = append(rootIDs, rootID) - } - } - - if err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { - if err := tx.Nodes().Filterf(func() graph.Criteria { - return query.InIDs(query.NodeID(), rootIDs...) - }).Update(tierZeroProperties); err != nil { + for _, domain := range domains { + if roots, err := adAnalysis.FetchActiveDirectoryTierZeroRoots(ctx, db, domain); err != nil { + return err + } else { + properties := graph.NewProperties() + properties.Set(common.SystemTags.String(), ad.AdminTierZero) + + if err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { + return tx.Nodes().Filter(query.InIDs(query.Node(), roots.IDs()...)).Update(properties) + }); err != nil { return err } - - return nil - }); err != nil { - log.Errorf("Failed tagging update: %v", err) } - }() - - for workerID := 0; workerID < commonanalysis.MaximumDatabaseParallelWorkers; workerID++ { - writerWG.Add(1) - - go func(workerID int) { - defer writerWG.Done() - - if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { - for domain := range domainC { - if roots, err := adAnalysis.FetchActiveDirectoryTierZeroRoots(tx, domain); err != nil { - log.Errorf("Failed fetching tier zero for domain %d: %v", domain.ID, err) - } else { - for _, root := range roots { - rootsC <- root.ID - } - } - } - - return nil - }); err != nil { - log.Errorf("Error reading tier zero for domains: %v", err) - } - }(workerID) } - - for _, domain := range domains { - domainC <- domain - } - - close(domainC) - writerWG.Wait() - - close(rootsC) - readerWG.Wait() } return nil diff --git a/cmd/api/src/daemons/datapipe/analysis.go b/cmd/api/src/daemons/datapipe/analysis.go index 1967d908d2..3ca3be73f9 100644 --- a/cmd/api/src/daemons/datapipe/analysis.go +++ 
b/cmd/api/src/daemons/datapipe/analysis.go @@ -54,7 +54,7 @@ func RunAnalysisOperations(ctx context.Context, db database.Database, graphDB gr collector.Collect(fmt.Errorf("asset group isolation tagging failed: %w", err)) } - if err := ParallelTagActiveDirectoryTierZero(ctx, graphDB); err != nil { + if err := TagActiveDirectoryTierZero(ctx, graphDB); err != nil { collector.Collect(fmt.Errorf("active directory tier zero tagging failed: %w", err)) } diff --git a/cmd/api/src/daemons/datapipe/datapipe.go b/cmd/api/src/daemons/datapipe/datapipe.go index 110396f593..2baf804def 100644 --- a/cmd/api/src/daemons/datapipe/datapipe.go +++ b/cmd/api/src/daemons/datapipe/datapipe.go @@ -19,6 +19,7 @@ package datapipe import ( "context" + "github.com/specterops/bloodhound/src/bootstrap" "os" "path/filepath" "sync" @@ -45,7 +46,6 @@ type Tasker interface { } type Daemon struct { - exitC chan struct{} db database.Database graphdb graph.Database cache cache.Cache @@ -65,14 +65,13 @@ func (s *Daemon) Name() string { return "Data Pipe Daemon" } -func NewDaemon(cfg config.Configuration, db database.Database, graphdb graph.Database, cache cache.Cache, tickInterval time.Duration) *Daemon { +func NewDaemon(ctx context.Context, cfg config.Configuration, connections bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch], cache cache.Cache, tickInterval time.Duration) *Daemon { return &Daemon{ - exitC: make(chan struct{}), - db: db, - graphdb: graphdb, + db: connections.RDMS, + graphdb: connections.Graph, cache: cache, cfg: cfg, - ctx: context.Background(), + ctx: ctx, analysisRequested: false, lock: &sync.Mutex{}, @@ -154,7 +153,6 @@ func (s *Daemon) Start() { pruningTicker = time.NewTicker(pruningInterval) ) - defer close(s.exitC) defer datapipeLoopTimer.Stop() defer pruningTicker.Stop() @@ -164,6 +162,7 @@ func (s *Daemon) Start() { select { case <-pruningTicker.C: s.clearOrphanedData() + case <-datapipeLoopTimer.C: fileupload.ProcessStaleFileUploadJobs(s.db) @@ 
-177,21 +176,14 @@ func (s *Daemon) Start() { } datapipeLoopTimer.Reset(s.tickInterval) - case <-s.exitC: + + case <-s.ctx.Done(): return } } } func (s *Daemon) Stop(ctx context.Context) error { - s.exitC <- struct{}{} - - select { - case <-s.exitC: - case <-ctx.Done(): - return ctx.Err() - } - return nil } @@ -228,7 +220,7 @@ func (s *Daemon) clearOrphanedData() { // Check to see if we need to shutdown after every file deletion select { - case <-s.exitC: + case <-s.ctx.Done(): return default: } diff --git a/cmd/api/src/daemons/datapipe/ingest.go b/cmd/api/src/daemons/datapipe/ingest.go index b502fe565f..fe80784502 100644 --- a/cmd/api/src/daemons/datapipe/ingest.go +++ b/cmd/api/src/daemons/datapipe/ingest.go @@ -271,6 +271,7 @@ func IngestRelationship(batch graph.Batch, nowUTC time.Time, nodeIDKind graph.Ki Start: graph.PrepareNode(graph.AsProperties(graph.PropertyMap{ common.ObjectID: nextRel.Source, + common.LastSeen: nowUTC, }), nextRel.SourceType), StartIdentityKind: nodeIDKind, StartIdentityProperties: []string{ @@ -279,6 +280,7 @@ func IngestRelationship(batch graph.Batch, nowUTC time.Time, nodeIDKind graph.Ki End: graph.PrepareNode(graph.AsProperties(graph.PropertyMap{ common.ObjectID: nextRel.Target, + common.LastSeen: nowUTC, }), nextRel.TargetType), EndIdentityKind: nodeIDKind, EndIdentityProperties: []string{ diff --git a/cmd/api/src/database/audit.go b/cmd/api/src/database/audit.go index a2c26841ad..941e6a5695 100644 --- a/cmd/api/src/database/audit.go +++ b/cmd/api/src/database/audit.go @@ -72,14 +72,24 @@ func (s *BloodhoundDB) ListAuditLogs(before, after time.Time, offset, limit int, // This code went through a partial refactor when adding support for new fields. 
// See the comments here for more information: https://github.com/SpecterOps/BloodHound/pull/297#issuecomment-1887640827 + if filter.SQLString != "" { + result = s.db.Model(&auditLogs).Where(filter.SQLString, filter.Params).Count(&count) + } else { + result = s.db.Model(&auditLogs).Count(&count) + } + + if result.Error != nil { + return nil, 0, CheckError(result) + } + if order != "" && filter.SQLString == "" { - result = cursor.Order(order).Find(&auditLogs).Count(&count) + result = cursor.Order(order).Find(&auditLogs) } else if order != "" && filter.SQLString != "" { - result = cursor.Where(filter.SQLString, filter.Params).Order(order).Find(&auditLogs).Count(&count) + result = cursor.Where(filter.SQLString, filter.Params).Order(order).Find(&auditLogs) } else if order == "" && filter.SQLString != "" { - result = cursor.Where(filter.SQLString, filter.Params).Find(&auditLogs).Count(&count) + result = cursor.Where(filter.SQLString, filter.Params).Find(&auditLogs) } else { - result = cursor.Find(&auditLogs).Count(&count) + result = cursor.Find(&auditLogs) } return auditLogs, int(count), CheckError(result) diff --git a/cmd/api/src/database/audit_test.go b/cmd/api/src/database/audit_test.go new file mode 100644 index 0000000000..683fe5c35c --- /dev/null +++ b/cmd/api/src/database/audit_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration +// +build integration + +package database_test + +import ( + "github.com/specterops/bloodhound/src/auth" + "github.com/specterops/bloodhound/src/ctx" + "github.com/specterops/bloodhound/src/model" + "github.com/specterops/bloodhound/src/test/integration" + "testing" + "time" +) + +func TestDatabase_ListAuditLogs(t *testing.T) { + var ( + dbInst = integration.OpenDatabase(t) + + auditLogIdFilter = model.QueryParameterFilter{ + Name: "id", + Operator: model.GreaterThan, + Value: "4", + IsStringData: false, + } + auditLogIdFilterMap = model.QueryParameterFilterMap{auditLogIdFilter.Name: model.QueryParameterFilters{auditLogIdFilter}} + ) + + if err := integration.Prepare(dbInst); err != nil { + t.Fatalf("Failed preparing DB: %v", err) + } + + mockCtx := ctx.Context{ + RequestID: "requestID", + AuthCtx: auth.Context{ + Owner: model.User{}, + Session: model.UserSession{}, + }, + } + for i := 0; i < 7; i++ { + if err := dbInst.AppendAuditLog(mockCtx, "CreateUser", model.User{}); err != nil { + t.Fatalf("Error creating audit log: %v", err) + } + } + + if _, count, err := dbInst.ListAuditLogs(time.Now(), time.Now(), 0, 10, "", model.SQLFilter{}); err != nil { + t.Fatalf("Failed to list all audit logs: %v", err) + } else if count != 7 { + t.Fatalf("Expected 7 audit logs to be returned") + } else if filter, err := auditLogIdFilterMap.BuildSQLFilter(); err != nil { + t.Fatalf("Failed to generate SQL Filter: %v", err) + // Limit is set to 1 to verify that count is total filtered count, not response size + } else if _, count, err = dbInst.ListAuditLogs(time.Now(), time.Now(), 0, 1, "", filter); err != nil { + t.Fatalf("Failed to list filtered events: %v", err) + } else if count != 3 { + t.Fatalf("Expected 3 audit logs to be returned") + } +} diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index d3e7f76653..21bbe8de89 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go 
@@ -51,6 +51,7 @@ type Database interface { appcfg.ParameterService appcfg.FeatureFlagService + Close() GetConfigurationParameter(parameter string) (appcfg.Parameter, error) SetConfigurationParameter(appConfig appcfg.Parameter) error GetAllConfigurationParameters() (appcfg.Parameters, error) @@ -151,6 +152,14 @@ type BloodhoundDB struct { idResolver auth.IdentityResolver // TODO: this really needs to be elsewhere. something something separation of concerns } +func (s *BloodhoundDB) Close() { + if sqlDBRef, err := s.db.DB(); err != nil { + log.Errorf("Failed to fetch SQL DB reference from GORM: %v", err) + } else if err := sqlDBRef.Close(); err != nil { + log.Errorf("Failed closing database: %v", err) + } +} + func (s *BloodhoundDB) preload(associations []string) *gorm.DB { cursor := s.db for _, association := range associations { @@ -200,7 +209,7 @@ func (s *BloodhoundDB) Wipe() error { return s.db.Transaction(func(tx *gorm.DB) error { var tables []string - if result := tx.Raw("select table_name from information_schema.tables where table_schema = current_schema()").Scan(&tables); result.Error != nil { + if result := tx.Raw("select table_name from information_schema.tables where table_schema = current_schema() and not table_name ilike '%pg_stat%'").Scan(&tables); result.Error != nil { return result.Error } diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index d8b3fbc6c0..6e58a3d86e 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -68,6 +68,18 @@ func (mr *MockDatabaseMockRecorder) AppendAuditLog(arg0, arg1, arg2 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendAuditLog", reflect.TypeOf((*MockDatabase)(nil).AppendAuditLog), arg0, arg1, arg2) } +// Close mocks base method. +func (m *MockDatabase) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. 
+func (mr *MockDatabaseMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatabase)(nil).Close)) +} + // CreateADDataQualityAggregation mocks base method. func (m *MockDatabase) CreateADDataQualityAggregation(arg0 model.ADDataQualityAggregation) (model.ADDataQualityAggregation, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/saved_queries.go b/cmd/api/src/database/saved_queries.go index 26022fd1e8..7518ac938f 100644 --- a/cmd/api/src/database/saved_queries.go +++ b/cmd/api/src/database/saved_queries.go @@ -19,30 +19,33 @@ package database import ( "github.com/gofrs/uuid" "github.com/specterops/bloodhound/src/model" + "gorm.io/gorm" ) func (s *BloodhoundDB) ListSavedQueries(userID uuid.UUID, order string, filter model.SQLFilter, skip, limit int) (model.SavedQueries, int, error) { var ( queries model.SavedQueries + result *gorm.DB count int64 + cursor = s.Scope(Paginate(skip, limit)).Where("user_id = ?", userID) ) - cursor := s.Scope(Paginate(skip, limit)).Where("user_id = ?", userID) - if filter.SQLString != "" { cursor = cursor.Where(filter.SQLString, filter.Params) + result = s.db.Model(&queries).Where("user_id = ?", userID).Where(filter.SQLString, filter.Params).Count(&count) + } else { + result = s.db.Model(&queries).Where("user_id = ?", userID).Count(&count) } - if order != "" { - cursor = cursor.Order(order) - } - - result := s.db.Where("user_id = ?", userID).Find(&queries).Count(&count) if result.Error != nil { return queries, 0, result.Error } + if order != "" { + cursor = cursor.Order(order) + } result = cursor.Find(&queries) + return queries, int(count), CheckError(result) } diff --git a/cmd/api/src/database/saved_queries_test.go b/cmd/api/src/database/saved_queries_test.go new file mode 100644 index 0000000000..f2e8ee1f2b --- /dev/null +++ b/cmd/api/src/database/saved_queries_test.go @@ -0,0 +1,69 @@ +// Copyright 2023 Specter Ops, Inc. 
+// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration +// +build integration + +package database_test + +import ( + "fmt" + "github.com/gofrs/uuid" + "github.com/specterops/bloodhound/src/model" + "github.com/specterops/bloodhound/src/test/integration" + "github.com/stretchr/testify/require" + "testing" +) + +func TestSavedQueries_ListSavedQueries(t *testing.T) { + var ( + dbInst = integration.OpenDatabase(t) + + savedQueriesFilter = model.QueryParameterFilter{ + Name: "id", + Operator: model.GreaterThan, + Value: "4", + IsStringData: false, + } + savedQueriesFilterMap = model.QueryParameterFilterMap{savedQueriesFilter.Name: model.QueryParameterFilters{savedQueriesFilter}} + ) + + if err := integration.Prepare(dbInst); err != nil { + t.Fatalf("Failed preparing DB: %v", err) + } + + userUUID, err := uuid.NewV4() + require.Nil(t, err) + + for i := 0; i < 7; i++ { + if _, err := dbInst.CreateSavedQuery(userUUID, fmt.Sprintf("saved_query_%d", i), ""); err != nil { + t.Fatalf("Error creating audit log: %v", err) + } + } + + if _, count, err := dbInst.ListSavedQueries(userUUID, "", model.SQLFilter{}, 0, 10); err != nil { + t.Fatalf("Failed to list all saved queries: %v", err) + } else if count != 7 { + t.Fatalf("Expected 7 saved queries to be returned") + } else if filter, err := savedQueriesFilterMap.BuildSQLFilter(); err != nil { + t.Fatalf("Failed to generate SQL Filter: %v", err) + // Limit is set to 1 to verify 
that count is total filtered count, not response size + } else if _, count, err = dbInst.ListSavedQueries(userUUID, "", filter, 0, 1); err != nil { + t.Fatalf("Failed to list filtered saved queries: %v", err) + } else if count != 3 { + t.Fatalf("Expected 3 saved queries to be returned") + } +} diff --git a/cmd/api/src/go.mod b/cmd/api/src/go.mod index 7f6cd7b573..58c637338b 100644 --- a/cmd/api/src/go.mod +++ b/cmd/api/src/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/src -go 1.20 +go 1.21 require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 @@ -32,8 +32,10 @@ require ( github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/schema v1.2.0 + github.com/jackc/pgx/v5 v5.5.1 github.com/jedib0t/go-pretty/v6 v6.4.6 github.com/mattermost/xml-roundtrip-validator v0.1.0 + github.com/neo4j/neo4j-go-driver/v5 v5.9.0 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.4.0 github.com/prometheus/client_golang v1.16.0 diff --git a/cmd/api/src/go.sum b/cmd/api/src/go.sum index 5b264c4adb..0a30eb18e9 100644 --- a/cmd/api/src/go.sum +++ b/cmd/api/src/go.sum @@ -129,11 +129,14 @@ github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgS github.com/jackc/pgx/v4 v4.16.1/go.mod h1:SIhx0D5hoADaiXZVyv+3gSm3LCIIINTVO0PficsvWGQ= github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0= github.com/jackc/pgx/v4 v4.18.1/go.mod h1:FydWkUyadDmdNH/mHnGob881GawxeEm7TcMCzkb+qQE= +github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.3.0 
h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jedib0t/go-pretty/v6 v6.4.6 h1:v6aG9h6Uby3IusSSEjHaZNXpHFhzqMmjXcPq1Rjl9Jw= github.com/jedib0t/go-pretty/v6 v6.4.6/go.mod h1:Ndk3ase2CkQbXLLNf5QDHoYb6J9WtVfmHZu9n8rk2xs= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -182,6 +185,7 @@ github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWV github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/neo4j/neo4j-go-driver/v5 v5.9.0 h1:TYxT0RSiwnvVFia90V7TLnRXv8HkdQQ6rTUaPVoyZ+w= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -294,6 +298,7 @@ golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/cmd/api/src/go.tools.mod b/cmd/api/src/go.tools.mod index bdd43e8eb8..88e074ea47 100644 --- a/cmd/api/src/go.tools.mod +++ b/cmd/api/src/go.tools.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/src -go 1.20 +go 1.21 require ( go.uber.org/mock v1.5.0 // indirect diff --git a/cmd/api/src/migrations/graph.go b/cmd/api/src/migrations/graph.go index 875ab5ee67..b9af27174d 100644 --- a/cmd/api/src/migrations/graph.go +++ b/cmd/api/src/migrations/graph.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package migrations @@ -19,12 +19,12 @@ package migrations import ( "context" "fmt" - - "github.com/specterops/bloodhound/src/version" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/query" + "github.com/specterops/bloodhound/graphschema" "github.com/specterops/bloodhound/graphschema/common" "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/version" ) type Migration struct { @@ -40,7 +40,12 @@ func NewGraphMigrator(db graph.Database) *GraphMigrator { return &GraphMigrator{db: db} } -func (s *GraphMigrator) Migrate() error { +func (s *GraphMigrator) Migrate(ctx context.Context, schema graph.Schema) error { + // Assert the schema first + if err := s.db.AssertSchema(ctx, schema); err != nil { + return err + } + // Perform stepwise migrations if err := s.executeStepwiseMigrations(); err != nil { return err @@ -96,8 +101,10 @@ func (s *GraphMigrator) getMigrationData() (version.Version, error) { node *graph.Node currentMigration version.Version ) + if err := s.db.ReadTransaction(context.Background(), func(tx graph.Transaction) error { var err error + if node, err = tx.Nodes().Filterf(func() graph.Criteria { return query.Kind(query.Node(), common.MigrationData) }).First(); err != nil { @@ -140,7 +147,7 @@ func (s *GraphMigrator) executeMigrations(target version.Version) error { } func (s *GraphMigrator) executeStepwiseMigrations() error { - if err := s.db.AssertSchema(context.Background(), CurrentSchema()); err != nil { + if err := s.db.AssertSchema(context.Background(), graphschema.DefaultGraphSchema()); err != nil { return fmt.Errorf("error asserting current schema: %w", err) } @@ -149,7 +156,9 @@ func (s *GraphMigrator) executeStepwiseMigrations() error { if err := s.createMigrationData(); err != nil { return fmt.Errorf("could not create graph db migration data: %w", err) } + currentVersion := version.GetVersion() + log.Infof("This is a new graph database. 
Creating a migration entry for GraphDB version %s", currentVersion) return s.updateMigrationData(currentVersion) } else { diff --git a/cmd/api/src/migrations/schema.go b/cmd/api/src/migrations/schema.go deleted file mode 100644 index d1c575325a..0000000000 --- a/cmd/api/src/migrations/schema.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2023 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 - -package migrations - -import ( - "github.com/specterops/bloodhound/dawgs/graph" - "github.com/specterops/bloodhound/graphschema/ad" - "github.com/specterops/bloodhound/graphschema/azure" - "github.com/specterops/bloodhound/graphschema/common" -) - -func CurrentSchema() *graph.Schema { - bhSchema := graph.NewSchema() - - bhSchema.DefineKinds(ad.NodeKinds()...) - bhSchema.DefineKinds(azure.NodeKinds()...) - - bhSchema.ConstrainProperty(common.ObjectID.String(), graph.FullTextSearchIndex) - - bhSchema.IndexProperty(common.Name.String(), graph.FullTextSearchIndex) - bhSchema.IndexProperty(common.SystemTags.String(), graph.FullTextSearchIndex) - bhSchema.IndexProperty(common.UserTags.String(), graph.FullTextSearchIndex) - - bhSchema.ForKinds(ad.Entity).Index(ad.DistinguishedName.String(), graph.BTreeIndex) - - bhSchema.ForKinds(ad.NodeKinds()...). - Index(ad.DomainFQDN.String(), graph.BTreeIndex). - Index(ad.DomainSID.String(), graph.BTreeIndex) - - bhSchema.ForKinds(azure.NodeKinds()...). 
- Index(azure.TenantID.String(), graph.BTreeIndex) - - bhSchema.ForKinds(ad.RootCA, ad.EnterpriseCA, ad.AIACA). - Index(ad.CertThumbprint.String(), graph.BTreeIndex) - - return bhSchema -} diff --git a/cmd/api/src/queries/graph.go b/cmd/api/src/queries/graph.go index d16348b855..0356c3ca5b 100644 --- a/cmd/api/src/queries/graph.go +++ b/cmd/api/src/queries/graph.go @@ -22,6 +22,10 @@ import ( "bytes" "context" "fmt" + "github.com/specterops/bloodhound/cypher/backend/cypher" + "github.com/specterops/bloodhound/cypher/backend/pgsql" + "github.com/specterops/bloodhound/dawgs/drivers/pg" + "github.com/specterops/bloodhound/src/config" "github.com/specterops/bloodhound/src/services/agi" "net/http" "net/url" @@ -149,18 +153,18 @@ type GraphQuery struct { Cache cache.Cache SlowQueryThreshold int64 // Threshold in milliseconds DisableCypherQC bool - cypherEmitter frontend.Emitter - strippedCypherEmitter frontend.Emitter + cypherEmitter cypher.Emitter + strippedCypherEmitter cypher.Emitter } -func NewGraphQuery(graphDB graph.Database, cache cache.Cache, slowQueryThreshold int64, disableCypherQC bool) *GraphQuery { +func NewGraphQuery(graphDB graph.Database, cache cache.Cache, cfg config.Configuration) *GraphQuery { return &GraphQuery{ Graph: graphDB, Cache: cache, - SlowQueryThreshold: slowQueryThreshold, - DisableCypherQC: disableCypherQC, - cypherEmitter: frontend.NewCypherEmitter(false), - strippedCypherEmitter: frontend.NewCypherEmitter(true), + SlowQueryThreshold: cfg.SlowQueryThreshold, + DisableCypherQC: cfg.DisableCypherQC, + cypherEmitter: cypher.NewCypherEmitter(false), + strippedCypherEmitter: cypher.NewCypherEmitter(true), } } @@ -262,7 +266,9 @@ func (s *GraphQuery) GetAllShortestPaths(ctx context.Context, startNodeID string return tx.Relationships().Filter(query.And(criteria...)).FetchAllShortestPaths(func(cursor graph.Cursor[graph.Path]) error { for path := range cursor.Chan() { - paths.AddPath(path) + if len(path.Edges) > 0 { + paths.AddPath(path) + } } 
return cursor.Error() @@ -280,7 +286,7 @@ var groupFilter = query.Not( ), ) -func searchNodeByKindAndEqualsName(kind graph.Kind, name string) graph.Criteria { +func SearchNodeByKindAndEqualsNameCriteria(kind graph.Kind, name string) graph.Criteria { return query.And( query.Kind(query.Node(), kind), query.Or( @@ -338,7 +344,7 @@ func (s *GraphQuery) SearchNodesByName(ctx context.Context, nodeKinds graph.Kind for _, kind := range nodeKinds { if err := s.Graph.ReadTransaction(ctx, func(tx graph.Transaction) error { - if exactMatchNodes, err := ops.FetchNodes(tx.Nodes().Filter(searchNodeByKindAndEqualsName(kind, formattedName))); err != nil { + if exactMatchNodes, err := ops.FetchNodes(tx.Nodes().Filter(SearchNodeByKindAndEqualsNameCriteria(kind, formattedName))); err != nil { return err } else { @@ -361,9 +367,9 @@ func (s *GraphQuery) SearchNodesByName(ctx context.Context, nodeKinds graph.Kind } type preparedQuery struct { - cypher string - strippedCypher string - complexity *analyzer.ComplexityMeasure + query string + strippedQuery string + complexity *analyzer.ComplexityMeasure } func (s *GraphQuery) prepareGraphQuery(rawCypher string, disableCypherQC bool) (preparedQuery, error) { @@ -379,13 +385,25 @@ func (s *GraphQuery) prepareGraphQuery(rawCypher string, disableCypherQC bool) ( return graphQuery, newQueryError(err) } else if !disableCypherQC && complexityMeasure.Weight > MaxQueryComplexityWeightAllowed { return graphQuery, newQueryError(ErrCypherQueryToComplex) + } else if pgDB, isPG := s.Graph.(*pg.Driver); isPG { + if _, err := pgsql.Translate(queryModel, pgDB.KindMapper()); err != nil { + return graphQuery, newQueryError(err) + } + + if err := pgsql.NewEmitter(false, pgDB.KindMapper()).Write(queryModel, buffer); err != nil { + return graphQuery, err + } else { + graphQuery.query = buffer.String() + } + + return graphQuery, nil } else { graphQuery.complexity = complexityMeasure if err := s.cypherEmitter.Write(queryModel, buffer); err != nil { return 
graphQuery, newQueryError(err) } else { - graphQuery.cypher = buffer.String() + graphQuery.query = buffer.String() } buffer.Reset() @@ -393,7 +411,7 @@ func (s *GraphQuery) prepareGraphQuery(rawCypher string, disableCypherQC bool) ( if err := s.strippedCypherEmitter.Write(queryModel, buffer); err != nil { return graphQuery, newQueryError(err) } else { - graphQuery.strippedCypher = buffer.String() + graphQuery.strippedQuery = buffer.String() } } @@ -410,11 +428,11 @@ func (s *GraphQuery) RawCypherSearch(ctx context.Context, rawCypher string, incl return graphResponse, err } else { logEvent := log.WithLevel(log.LevelInfo) - logEvent.Str("query", preparedQuery.strippedCypher) + logEvent.Str("query", preparedQuery.strippedQuery) logEvent.Msg("Executing user cypher query") return graphResponse, s.Graph.ReadTransaction(ctx, func(tx graph.Transaction) error { - if pathSet, err := ops.FetchPathSetByQuery(tx, preparedQuery.cypher); err != nil { + if pathSet, err := ops.FetchPathSetByQuery(tx, preparedQuery.query); err != nil { return err } else { graphResponse.AddPathSet(pathSet, includeProperties) diff --git a/cmd/api/src/queries/graph_integration_test.go b/cmd/api/src/queries/graph_integration_test.go index 957e3713b9..9de0616cb5 100644 --- a/cmd/api/src/queries/graph_integration_test.go +++ b/cmd/api/src/queries/graph_integration_test.go @@ -22,6 +22,8 @@ package queries_test import ( "context" "encoding/json" + schema "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/src/config" "testing" adAnalysis "github.com/specterops/bloodhound/analysis/ad" @@ -39,18 +41,19 @@ import ( ) func TestSearchNodesByName_ExactMatch(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) error { harness.SearchHarness.Setup(testContext) + 
return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { + func(harness integration.HarnessDetails, db graph.Database) { var ( userWanted = "USER NUMBER ONE" skip = 0 limit = 10 - graphQuery = queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + graphQuery = queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) ) results, err := graphQuery.SearchNodesByName(context.Background(), graph.Kinds{azure.Entity, ad.Entity}, userWanted, skip, limit) @@ -58,96 +61,92 @@ func TestSearchNodesByName_ExactMatch(t *testing.T) { require.Nil(t, err) expectedUser := results[0] require.Equal(t, expectedUser.Name, userWanted) - - return nil }) } func TestSearchNodesByName_FuzzyMatch(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) error { harness.SearchHarness.Setup(testContext) + return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { + func(harness integration.HarnessDetails, db graph.Database) { var ( userWanted = "USER NUMBER" skip = 0 limit = 10 - graphQuery = queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + graphQuery = queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) ) results, err := graphQuery.SearchNodesByName(context.Background(), graph.Kinds{azure.Entity, ad.Entity}, userWanted, skip, limit) require.Nil(t, err) require.Equal(t, 5, len(results), "All users that contain `USER NUMBER` should be returned ") - - return nil }) } func TestSearchNodesByName_NoADLocalGroup(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) 
error { harness.SearchHarness.Setup(testContext) + return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { + func(harness integration.HarnessDetails, db graph.Database) { var ( userWanted = "Remote Desktop" skip = 0 limit = 10 - graphQuery = queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + graphQuery = queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) ) results, err := graphQuery.SearchNodesByName(context.Background(), graph.Kinds{azure.Entity, ad.Entity}, userWanted, skip, limit) require.Nil(t, err) require.Equal(t, 0, len(results), "No ADLocalGroup nodes should be returned ") - - return nil }) } func TestSearchNodesByName_GroupLocalGroupCorrect(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) error { harness.SearchHarness.Setup(testContext) + return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { + func(harness integration.HarnessDetails, db graph.Database) { var ( - userWanted = "Account Op" - skip = 0 - limit = 10 - graphQuery = queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + groupWanted = "Account Op" + skip = 0 + limit = 10 + graphQuery = queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) ) - results, err := graphQuery.SearchNodesByName(context.Background(), graph.Kinds{azure.Entity, ad.Entity}, userWanted, skip, limit) + results, err := graphQuery.SearchNodesByName(context.Background(), graph.Kinds{azure.Entity, ad.Entity}, groupWanted, skip, limit) require.Nil(t, err) require.Equal(t, 1, len(results), ":ADLocalGroup nodes should return if they are also :Group nodes") - - return nil }) } func TestSearchNodesByName_ExactMatch_ObjectID(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + 
testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) error { harness.SearchHarness.Setup(testContext) + return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { + func(harness integration.HarnessDetails, db graph.Database) { var ( userObjectId = harness.SearchHarness.User1.Properties.Get(common.ObjectID.String()) skip = 0 limit = 10 - graphQuery = queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + graphQuery = queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) ) searchQuery, _ := userObjectId.String() @@ -159,17 +158,15 @@ func TestSearchNodesByName_ExactMatch_ObjectID(t *testing.T) { require.Nil(t, err) require.Equal(t, 1, len(results), "Only one user can match exactly one Object ID") require.Equal(t, searchQuery, actual.ObjectID) - - return nil }) } func TestGetEntityResults(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) queryCache, err := cache.NewCache(cache.Config{MaxSize: 1}) require.Nil(t, err) - testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) error { + testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) { objectID, err := harness.InboundControl.ControlledUser.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) @@ -201,17 +198,15 @@ func TestGetEntityResults(t *testing.T) { require.Equal(t, 1, results.Skip) require.LessOrEqual(t, 2, len(results.Data.([]any))) require.Equal(t, 0, queryCache.Len()) - - return nil }) } func TestGetEntityResults_QueryShorterThanSlowQueryThreshold(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) queryCache, err := 
cache.NewCache(cache.Config{MaxSize: 1}) require.Nil(t, err) - testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) error { + testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) { objectID, err := harness.InboundControl.ControlledUser.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) @@ -244,17 +239,15 @@ func TestGetEntityResults_QueryShorterThanSlowQueryThreshold(t *testing.T) { require.Equal(t, 1, results.Skip) require.LessOrEqual(t, 2, len(results.Data.([]any))) require.Equal(t, 0, queryCache.Len()) - - return nil }) } func TestGetEntityResults_Cache(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) queryCache, err := cache.NewCache(cache.Config{MaxSize: 2}) require.Nil(t, err) - testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) error { + testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) { objectID, err := harness.InboundControl.ControlledUser.Properties.Get(common.ObjectID.String()).String() require.Nil(t, err) @@ -305,15 +298,13 @@ func TestGetEntityResults_Cache(t *testing.T) { require.Equal(t, 1, results.Skip) require.LessOrEqual(t, 2, len(results.Data.([]any))) require.Equal(t, 1, queryCache.Len()) - - return nil }) } func TestGetAssetGroupComboNode(t *testing.T) { - testContext := integration.NewGraphTestContext(t) - testContext.TransactionalTest(func(harness integration.HarnessDetails, tx graph.Transaction) { - graphQuery := queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) + testContext.DatabaseTest(func(harness integration.HarnessDetails, db graph.Database) { + graphQuery := queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) comboNode, err := 
graphQuery.GetAssetGroupComboNode(context.Background(), "", ad.AdminTierZero) require.Nil(t, err) @@ -330,9 +321,9 @@ func TestGetAssetGroupComboNode(t *testing.T) { } func TestGraphQuery_GetAllShortestPaths(t *testing.T) { - testContext := integration.NewGraphTestContext(t) + testContext := integration.NewGraphTestContext(t, schema.DefaultGraphSchema()) testContext.DatabaseTestWithSetup( - func(harness *integration.HarnessDetails) { + func(harness *integration.HarnessDetails) error { var ( userA = testContext.NewNode(graph.AsProperties(graph.PropertyMap{ common.Name: "A", @@ -353,9 +344,11 @@ func TestGraphQuery_GetAllShortestPaths(t *testing.T) { testContext.NewRelationship(userA, groupA, ad.MemberOf) testContext.NewRelationship(groupA, computer, ad.GenericAll) testContext.NewRelationship(userA, computer, ad.GenericWrite) + + return nil }, - func(harness integration.HarnessDetails, db graph.Database) error { - graphQuery := queries.NewGraphQuery(testContext.GraphDB, cache.Cache{}, 0, false) + func(harness integration.HarnessDetails, db graph.Database) { + graphQuery := queries.NewGraphQuery(db, cache.Cache{}, config.Configuration{}) paths, err := graphQuery.GetAllShortestPaths(context.Background(), "A", "C", query.KindIn(query.Relationship(), ad.Relationships()...)) require.Nil(t, err) @@ -372,7 +365,5 @@ func TestGraphQuery_GetAllShortestPaths(t *testing.T) { require.Nil(t, err) require.Equal(t, 0, len(paths)) - - return nil }) } diff --git a/cmd/api/src/queries/graph_test.go b/cmd/api/src/queries/graph_test.go index 3fff1175f4..f222d79f41 100644 --- a/cmd/api/src/queries/graph_test.go +++ b/cmd/api/src/queries/graph_test.go @@ -19,6 +19,7 @@ package queries_test import ( "context" "fmt" + "github.com/specterops/bloodhound/src/config" "net/http" "net/url" "testing" @@ -45,7 +46,7 @@ func TestGraphQuery_RawCypherSearch(t *testing.T) { var ( mockCtrl = gomock.NewController(t) mockGraphDB = graphMocks.NewMockDatabase(mockCtrl) - gq = 
queries.NewGraphQuery(mockGraphDB, cache.Cache{}, 0, false) + gq = queries.NewGraphQuery(mockGraphDB, cache.Cache{}, config.Configuration{}) outerBHCtxInst = &bhCtx.Context{ StartTime: time.Now(), Timeout: bhCtx.RequestedWaitDuration{ diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go new file mode 100644 index 0000000000..223a51d5a9 --- /dev/null +++ b/cmd/api/src/services/entrypoint.go @@ -0,0 +1,113 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package services + +import ( + "context" + "fmt" + schema "github.com/specterops/bloodhound/graphschema" + "github.com/specterops/bloodhound/log" + "github.com/specterops/bloodhound/src/bootstrap" + "github.com/specterops/bloodhound/src/queries" + "time" + + "github.com/specterops/bloodhound/cache" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/src/api" + "github.com/specterops/bloodhound/src/api/registration" + "github.com/specterops/bloodhound/src/api/router" + "github.com/specterops/bloodhound/src/auth" + "github.com/specterops/bloodhound/src/config" + "github.com/specterops/bloodhound/src/daemons" + "github.com/specterops/bloodhound/src/daemons/api/bhapi" + "github.com/specterops/bloodhound/src/daemons/api/toolapi" + "github.com/specterops/bloodhound/src/daemons/datapipe" + "github.com/specterops/bloodhound/src/daemons/gc" + "github.com/specterops/bloodhound/src/database" + "github.com/specterops/bloodhound/src/model/appcfg" +) + +// ConnectPostgres initializes a connection to PG, and returns errors if any +func ConnectPostgres(cfg config.Configuration) (*database.BloodhoundDB, error) { + if db, err := database.OpenDatabase(cfg.Database.PostgreSQLConnectionString()); err != nil { + return nil, fmt.Errorf("error while attempting to create database connection: %w", err) + } else { + return database.NewBloodhoundDB(db, auth.NewIdentityResolver()), nil + } +} + +// ConnectDatabases initializes connections to PG and connection, and returns errors if any +func ConnectDatabases(ctx context.Context, cfg config.Configuration) (bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch], error) { + connections := bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch]{} + + if db, err := ConnectPostgres(cfg); err != nil { + return connections, err + } else if graphDB, err := bootstrap.ConnectGraph(ctx, cfg); err != nil { + return connections, err + } 
else { + connections.RDMS = db + connections.Graph = graphDB + + return connections, nil + } +} + +func Entrypoint(ctx context.Context, cfg config.Configuration, connections bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch]) ([]daemons.Daemon, error) { + if !cfg.DisableMigrations { + if err := bootstrap.MigrateDB(cfg, connections.RDMS); err != nil { + return nil, fmt.Errorf("rdms migration error: %w", err) + } else if err := bootstrap.MigrateGraph(ctx, connections.Graph, schema.DefaultGraphSchema()); err != nil { + return nil, fmt.Errorf("graph migration error: %w", err) + } + } else { + log.Infof("Database migrations are disabled per configuration") + } + + if apiCache, err := cache.NewCache(cache.Config{MaxSize: cfg.MaxAPICacheSize}); err != nil { + return nil, fmt.Errorf("failed to create in-memory cache for API: %w", err) + } else if graphQueryCache, err := cache.NewCache(cache.Config{MaxSize: cfg.MaxAPICacheSize}); err != nil { + return nil, fmt.Errorf("failed to create in-memory cache for graph queries: %w", err) + } else if collectorManifests, err := cfg.SaveCollectorManifests(); err != nil { + return nil, fmt.Errorf("failed to save collector manifests: %w", err) + } else { + var ( + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + datapipeDaemon = datapipe.NewDaemon(ctx, cfg, connections, graphQueryCache, time.Duration(cfg.DatapipeInterval)*time.Second) + routerInst = router.NewRouter(cfg, auth.NewAuthorizer(), bootstrap.ContentSecurityPolicy) + ctxInitializer = database.NewContextInitializer(connections.RDMS) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, ctxInitializer) + ) + + registration.RegisterFossGlobalMiddleware(&routerInst, cfg, auth.NewIdentityResolver(), authenticator) + registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, datapipeDaemon) + + // Set neo4j batch and flush sizes + 
neo4jParameters := appcfg.GetNeo4jParameters(connections.RDMS) + connections.Graph.SetBatchWriteSize(neo4jParameters.BatchWriteSize) + connections.Graph.SetWriteFlushSize(neo4jParameters.WriteFlushSize) + + // Trigger analysis on first start + datapipeDaemon.RequestAnalysis() + + return []daemons.Daemon{ + bhapi.NewDaemon(cfg, routerInst.Handler()), + toolapi.NewDaemon(ctx, connections, cfg, schema.DefaultGraphSchema()), + gc.NewDataPruningDaemon(connections.RDMS), + datapipeDaemon, + }, nil + } +} diff --git a/cmd/api/src/test/ctrl.go b/cmd/api/src/test/ctrl.go index 97c55b2c4b..4ee7864305 100644 --- a/cmd/api/src/test/ctrl.go +++ b/cmd/api/src/test/ctrl.go @@ -1,24 +1,75 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package test +import ( + "context" + "time" +) + type Controller interface { Cleanup(func()) Errorf(format string, args ...any) Fatalf(format string, args ...any) FailNow() } + +type Context interface { + Controller + context.Context +} + +type controllerInstance struct { + Controller + ctx context.Context +} + +func (s controllerInstance) Deadline() (deadline time.Time, ok bool) { + return s.ctx.Deadline() +} + +func (s controllerInstance) Done() <-chan struct{} { + return s.ctx.Done() +} + +func (s controllerInstance) Err() error { + return s.ctx.Err() +} + +func (s controllerInstance) Value(key any) any { + return s.ctx.Value(key) +} + +func (s controllerInstance) Context() context.Context { + return s +} + +func WithContext(parentCtx context.Context, controller Controller) Context { + testCtx, doneFunc := context.WithCancel(parentCtx) + + // Ensure the done function for the context is called by test cleanup + controller.Cleanup(doneFunc) + + return controllerInstance{ + Controller: controller, + ctx: testCtx, + } +} + +func NewContext(controller Controller) Context { + return WithContext(context.Background(), controller) +} diff --git a/cmd/api/src/test/fixtures/fixtures/expected_ingest.go b/cmd/api/src/test/fixtures/fixtures/expected_ingest.go index 7b06ce0c1a..427811009f 100644 --- a/cmd/api/src/test/fixtures/fixtures/expected_ingest.go +++ b/cmd/api/src/test/fixtures/fixtures/expected_ingest.go @@ -18,8 +18,7 @@ package fixtures import ( "bytes" - - "github.com/specterops/bloodhound/cypher/frontend" + "github.com/specterops/bloodhound/cypher/backend/cypher" "github.com/specterops/bloodhound/cypher/model" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/query" @@ -245,11 +244,19 @@ var ( query.Equals(query.EndProperty(common.ObjectID.String()), "S-1-5-21-3130019616-2776909439-2417379446-2117"), query.Equals(query.RelationshipProperty(ad.LogonType.String()), 2)), } + 
v6ingestRelationshipAssertionCriteria = []graph.Criteria{ + query.And( + query.Kind(query.Start(), ad.Computer), + query.Equals(query.StartProperty(common.ObjectID.String()), "S-1-5-21-3130019616-2776909439-2417379446-1001"), + query.Kind(query.Relationship(), ad.DCFor), + query.Kind(query.End(), ad.Domain), + query.Equals(query.EndProperty(common.ObjectID.String()), "S-1-5-21-3130019616-2776909439-2417379446")), + } ) func FormatQueryComponent(criteria graph.Criteria) string { var ( - emitter = frontend.NewCypherEmitter(false) + emitter = cypher.NewCypherEmitter(false) stringBuffer = &bytes.Buffer{} ) @@ -266,3 +273,10 @@ func IngestAssertions(testCtrl test.Controller, tx graph.Transaction) { require.Nilf(testCtrl, err, "Unable to find an expected relationship: %s", FormatQueryComponent(assertionCriteria)) } } + +func IngestAssertionsv6(testCtrl test.Controller, tx graph.Transaction) { + for _, assertionCriteria := range v6ingestRelationshipAssertionCriteria { + _, err := tx.Relationships().Filter(assertionCriteria).First() + require.Nilf(testCtrl, err, "Unable to find an expected relationship: %s", FormatQueryComponent(assertionCriteria)) + } +} diff --git a/cmd/api/src/test/fixtures/fixtures/expected_ingest_adcs.go b/cmd/api/src/test/fixtures/fixtures/expected_ingest_adcs.go index e6b5111810..f5141d48ac 100644 --- a/cmd/api/src/test/fixtures/fixtures/expected_ingest_adcs.go +++ b/cmd/api/src/test/fixtures/fixtures/expected_ingest_adcs.go @@ -549,6 +549,9 @@ func IngestADCSAssertions(testCtrl test.Controller, tx graph.Transaction) { for _, assertionCriteria := range nodeAssertionCriteria { _, err := tx.Nodes().Filter(assertionCriteria).First() + if err != nil { + tx.Nodes().Filter(assertionCriteria).First() + } require.Nilf(testCtrl, err, "Node assertion failed: %s", FormatQueryComponent(assertionCriteria)) } } diff --git a/cmd/api/src/test/fixtures/fixtures/v6/ingest/computers.json b/cmd/api/src/test/fixtures/fixtures/v6/ingest/computers.json index 
4442c3c053..b4b455cbec 100644 --- a/cmd/api/src/test/fixtures/fixtures/v6/ingest/computers.json +++ b/cmd/api/src/test/fixtures/fixtures/v6/ingest/computers.json @@ -343,7 +343,9 @@ ], "ObjectIdentifier": "S-1-5-21-3130019616-2776909439-2417379446-1001", "IsDeleted": false, - "IsACLProtected": false + "IsACLProtected": false, + "IsDC": true, + "DomainSID": "S-1-5-21-3130019616-2776909439-2417379446" }, { "Properties": { diff --git a/cmd/api/src/test/integration/context.go b/cmd/api/src/test/integration/context.go new file mode 100644 index 0000000000..ca425003e7 --- /dev/null +++ b/cmd/api/src/test/integration/context.go @@ -0,0 +1,82 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package integration + +import ( + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/src/test" +) + +type GraphContext struct { + Database graph.Database + schema graph.Schema +} + +func (s *GraphContext) BatchOperation(ctx test.Context, delegate graph.BatchDelegate) { + test.RequireNilErr(ctx, s.Database.BatchOperation(ctx, delegate)) +} + +func (s *GraphContext) ReadTransaction(ctx test.Context, delegate graph.TransactionDelegate) { + test.RequireNilErr(ctx, s.Database.ReadTransaction(ctx, delegate)) +} + +func (s *GraphContext) WriteTransaction(ctx test.Context, delegate graph.TransactionDelegate) { + test.RequireNilErr(ctx, s.Database.WriteTransaction(ctx, delegate)) +} + +func (s *GraphContext) wipe(ctx test.Context) { + s.WriteTransaction(ctx, func(tx graph.Transaction) error { + if nodeCount, err := tx.Nodes().Count(); err != nil { + return err + } else if nodeCount > 0 { + return tx.Nodes().Delete() + } + + return nil + }) +} + +func (s *GraphContext) Begin(ctx test.Context) { + // Clear the graph to ensure a clean slate + s.wipe(ctx) + + // Assert the graph schema before continuing + test.RequireNilErr(ctx, s.Database.AssertSchema(ctx, s.schema)) +} + +func (s *GraphContext) End(t test.Context) { + if err := s.Database.Close(t); err != nil { + t.Fatalf("Error encountered while closing the database: %v", err) + } +} + +func NewGraphContext(ctx test.Context, schema graph.Schema) *GraphContext { + graphContext := &GraphContext{ + schema: schema, + Database: OpenGraphDB(ctx), + } + + // Initialize the graph context + graphContext.Begin(ctx) + + // Ensure that the test cleans up after itself + ctx.Cleanup(func() { + graphContext.End(ctx) + }) + + return graphContext +} diff --git a/cmd/api/src/test/integration/dawgs.go b/cmd/api/src/test/integration/dawgs.go index 76a0a99d76..28763cd457 100644 --- a/cmd/api/src/test/integration/dawgs.go +++ b/cmd/api/src/test/integration/dawgs.go @@
-17,13 +17,15 @@ package integration import ( + "context" "github.com/specterops/bloodhound/dawgs" "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/specterops/bloodhound/dawgs/drivers/pg" "github.com/specterops/bloodhound/dawgs/graph" + schema "github.com/specterops/bloodhound/graphschema" "github.com/specterops/bloodhound/src/config" "github.com/specterops/bloodhound/src/test" "github.com/specterops/bloodhound/src/test/integration/utils" - "github.com/stretchr/testify/require" ) func LoadConfiguration(testCtrl test.Controller) config.Configuration { @@ -36,16 +38,30 @@ func LoadConfiguration(testCtrl test.Controller) config.Configuration { return cfg } -func OpenPostgresqlGDB(testCtrl test.Controller) graph.Database { - graphDatabase, err := dawgs.Open(neo4j.DriverName, dawgs.Config{DriverCfg: LoadConfiguration(testCtrl).Database.PostgreSQLConnectionString()}) - require.Nilf(testCtrl, err, "Failed connecting to graph database: %v", err) +func OpenGraphDB(testCtrl test.Controller) graph.Database { + var ( + cfg = LoadConfiguration(testCtrl) + graphDatabase graph.Database + err error + ) - return graphDatabase -} + switch cfg.GraphDriver { + case pg.DriverName: + graphDatabase, err = dawgs.Open(context.TODO(), cfg.GraphDriver, dawgs.Config{ + DriverCfg: cfg.Database.PostgreSQLConnectionString(), + }) + + case neo4j.DriverName: + graphDatabase, err = dawgs.Open(context.TODO(), cfg.GraphDriver, dawgs.Config{ + DriverCfg: cfg.Neo4J.Neo4jConnectionString(), + }) + + default: + testCtrl.Fatalf("unsupported graph driver name %s", cfg.GraphDriver) + } -func OpenNeo4jGraphDB(testCtrl test.Controller) graph.Database { - graphDatabase, err := dawgs.Open(neo4j.DriverName, dawgs.Config{DriverCfg: LoadConfiguration(testCtrl).Neo4J.Neo4jConnectionString()}) - require.Nilf(testCtrl, err, "Failed connecting to graph database: %v", err) + test.RequireNilErrf(testCtrl, err, "Failed connecting to graph database: %v", err) + test.RequireNilErr(testCtrl, 
graphDatabase.AssertSchema(context.Background(), schema.DefaultGraphSchema())) return graphDatabase } diff --git a/cmd/api/src/test/integration/graph.go b/cmd/api/src/test/integration/graph.go index 7b0ea17653..b5acb3632a 100644 --- a/cmd/api/src/test/integration/graph.go +++ b/cmd/api/src/test/integration/graph.go @@ -17,269 +17,156 @@ package integration import ( - "context" "fmt" "strings" "time" - "github.com/specterops/bloodhound/dawgs/cardinality" _ "github.com/specterops/bloodhound/dawgs/drivers/neo4j" "github.com/specterops/bloodhound/dawgs/graph" - "github.com/specterops/bloodhound/dawgs/ops" "github.com/specterops/bloodhound/graphschema/ad" "github.com/specterops/bloodhound/graphschema/azure" "github.com/specterops/bloodhound/graphschema/common" - "github.com/specterops/bloodhound/log" "github.com/specterops/bloodhound/src/test" "github.com/specterops/bloodhound/src/test/must" - "github.com/stretchr/testify/require" ) var DefaultRelProperties = graph.AsProperties(graph.PropertyMap{ common.LastSeen: time.Now().Format(time.RFC3339), }) -func NewGraphTestContext(testCtrl test.Controller) *GraphTestContext { - testCtx := &GraphTestContext{ - testCtrl: testCtrl, - nodesCreated: cardinality.NewBitmap32(), - GraphDB: OpenNeo4jGraphDB(testCtrl), - } +func NewGraphTestContext(testCtrl test.Controller, schema graph.Schema) *GraphTestContext { + testCtx := test.NewContext(testCtrl) - testCtrl.Cleanup(testCtx.Cleanup) - return testCtx + return &GraphTestContext{ + testCtx: testCtx, + Graph: NewGraphContext(testCtx, schema), + } } type GraphTestContext struct { - testCtrl test.Controller - tx graph.Transaction - nodesCreated cardinality.Duplex[uint32] - Harness HarnessDetails - GraphDB graph.Database + testCtx test.Context + Harness HarnessDetails + Graph *GraphContext +} + +// TODO: This is a responsibility violation +func (s *GraphTestContext) Context() test.Context { + return s.testCtx } func (s *GraphTestContext) NodeObjectID(node *graph.Node) string { objectID, 
err := node.Properties.Get(common.ObjectID.String()).String() - require.Nilf(s.testCtrl, err, "Expected node %d to have a valid %s property: %v", node.ID, common.ObjectID.String(), err) + + test.RequireNilErrf(s.testCtx, err, "expected node %d to have a valid %s property: %v", node.ID, common.ObjectID.String(), err) return objectID } func (s *GraphTestContext) FindNode(criteria graph.Criteria) *graph.Node { - var node *graph.Node + var ( + node *graph.Node + err error + ) - require.Nil(s.testCtrl, s.GraphDB.ReadTransaction(context.Background(), func(tx graph.Transaction) error { - fetchedNode, err := tx.Nodes().Filter(criteria).First() - node = fetchedNode + s.Graph.ReadTransaction(s.testCtx, func(tx graph.Transaction) error { + node, err = tx.Nodes().Filter(criteria).First() return err - })) + }) return node } func (s *GraphTestContext) UpdateNode(node *graph.Node) { - require.Nil(s.testCtrl, s.tx.UpdateNode(node)) -} - -func (s *GraphTestContext) Cleanup() { - if err := s.GraphDB.BatchOperation(context.Background(), func(batch graph.Batch) error { - return s.nodesCreated.Each(func(nodeID uint32) (bool, error) { - if err := batch.DeleteNode(graph.ID(nodeID)); err != nil { - return false, err - } - - return true, nil - }) - }); err != nil { - s.testCtrl.Errorf("Failed to clear DB after tests: %v", err) - } + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + return tx.UpdateNode(node) + }) } -func (s *GraphTestContext) EmptyDatabaseTest(dbDelegate func(harness HarnessDetails, db graph.Database) error) { - log.ConfigureDefaults() +func (s *GraphTestContext) DatabaseTest(dbDelegate func(harness HarnessDetails, db graph.Database)) { + s.setupActiveDirectory() + s.setupAzure() - require.Nil(s.testCtrl, dbDelegate(s.Harness, s.GraphDB)) + dbDelegate(s.Harness, s.Graph.Database) } -func (s *GraphTestContext) DatabaseTest(dbDelegate func(harness HarnessDetails, db graph.Database) error) { - log.ConfigureDefaults() - - require.Nil(s.testCtrl, 
s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() - - if err := tx.Nodes().Delete(); err != nil { - return err - } - - s.setupActiveDirectory() - s.setupAzure() - return nil - })) - - require.Nil(s.testCtrl, dbDelegate(s.Harness, s.GraphDB)) +func (s *GraphTestContext) SetupHarness(setup func(harness *HarnessDetails) error) { + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + return setup(&s.Harness) + }) } -func (s *GraphTestContext) DatabaseTestWithSetup(setup func(harness *HarnessDetails), dbDelegate func(harness HarnessDetails, db graph.Database) error) { - log.Configure(&log.Configuration{ - Level: log.LevelDebug, +func (s *GraphTestContext) DatabaseTestWithSetup(setup func(harness *HarnessDetails) error, dbDelegate func(harness HarnessDetails, db graph.Database)) { + // Wipe the DB before executing the test + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + return tx.Nodes().Delete() }) - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() - - if err := tx.Nodes().Delete(); err != nil { - return err - } - - setup(&s.Harness) - return nil - })) + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + return setup(&s.Harness) + }) - require.Nil(s.testCtrl, dbDelegate(s.Harness, s.GraphDB)) + dbDelegate(s.Harness, s.Graph.Database) } func (s *GraphTestContext) BatchTest(batchDelegate func(harness HarnessDetails, batch graph.Batch), assertionDelegate func(details HarnessDetails, tx graph.Transaction)) { - log.ConfigureDefaults() - - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() + s.setupActiveDirectory() + s.setupAzure() - if err := tx.Nodes().Delete(); err != nil { - return err - } - - 
s.setupActiveDirectory() - s.setupAzure() - return nil - })) - - require.Nil(s.testCtrl, s.GraphDB.BatchOperation(context.Background(), func(batch graph.Batch) error { + s.Graph.BatchOperation(s.testCtx, func(batch graph.Batch) error { batchDelegate(s.Harness, batch) return nil - })) + }) - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { + s.Graph.ReadTransaction(s.testCtx, func(tx graph.Transaction) error { assertionDelegate(s.Harness, tx) return nil - })) + }) } func (s *GraphTestContext) TransactionalTest(txDelegate func(harness HarnessDetails, tx graph.Transaction)) { - log.ConfigureDefaults() - - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() - - if err := tx.Nodes().Delete(); err != nil { - return err - } - - s.setupActiveDirectory() - s.setupAzure() - return nil - })) - - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() + s.setupActiveDirectory() + s.setupAzure() + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { txDelegate(s.Harness, tx) return nil - })) + }) } -func (s *GraphTestContext) ReadTransactionTest(setup func(harness *HarnessDetails), txDelegate func(harness HarnessDetails, tx graph.Transaction)) { - log.ConfigureDefaults() - - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() - - if err := tx.Nodes().Delete(); err != nil { - return err - } - - setup(&s.Harness) - return nil - })) - - require.Nil(s.testCtrl, s.GraphDB.ReadTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() +func (s *GraphTestContext) ReadTransactionTestWithSetup(setup func(harness *HarnessDetails) error, 
txDelegate func(harness HarnessDetails, tx graph.Transaction)) { + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + return setup(&s.Harness) + }) + s.Graph.ReadTransaction(s.testCtx, func(tx graph.Transaction) error { txDelegate(s.Harness, tx) return nil - })) + }) } -func (s *GraphTestContext) WriteTransactionTest(setup func(harness *HarnessDetails), txDelegate func(harness HarnessDetails, tx graph.Transaction)) { - log.ConfigureDefaults() - - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() - - if err := tx.Nodes().Delete(); err != nil { - return err - } - - setup(&s.Harness) - return nil - })) - - require.Nil(s.testCtrl, s.GraphDB.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - s.tx = tx - - defer func() { - s.tx = nil - }() +func (s *GraphTestContext) WriteTransactionTestWithSetup(setup func(harness *HarnessDetails) error, txDelegate func(harness HarnessDetails, tx graph.Transaction)) { + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + return setup(&s.Harness) + }) + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { txDelegate(s.Harness, tx) return nil - })) -} - -func (s *GraphTestContext) DeleteNode(tx graph.Transaction, target *graph.Node) { - err := ops.DeleteNodes(tx, target.ID) - require.Nilf(s.testCtrl, err, "Error deleting node: %v", err) - - s.nodesCreated.Remove(target.ID.Uint32()) + }) } func (s *GraphTestContext) NewNode(properties *graph.Properties, kinds ...graph.Kind) *graph.Node { - newNode, err := s.tx.CreateNode(properties, kinds...) - require.Nilf(s.testCtrl, err, "Error creating node: %v", err) + var ( + node *graph.Node + err error + ) + + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + node, err = tx.CreateNode(properties, kinds...) 
+ return err + }) - s.nodesCreated.Add(newNode.ID.Uint32()) - return newNode + return node } func (s *GraphTestContext) NewAzureApplication(name, objectID, tenantID string) *graph.Node { @@ -377,25 +264,23 @@ func (s *GraphTestContext) NewAzureSubscription(name, objectID, tenantID string) }), azure.Entity, azure.Subscription) } -func (s *GraphTestContext) NewRelationship(startNode, endNode *graph.Node, kind graph.Kind, properties ...*graph.Properties) *graph.Relationship { - var nodeProperties *graph.Properties - - if len(properties) > 0 { - nodeProperties = properties[0] +func (s *GraphTestContext) NewRelationship(startNode, endNode *graph.Node, kind graph.Kind, propertyBags ...*graph.Properties) *graph.Relationship { + var ( + relationshipProperties = graph.NewProperties() + relationship *graph.Relationship + err error + ) - if len(properties) > 1 { - for _, additionalProperties := range properties[1:] { - nodeProperties.SetAll(additionalProperties.Map) - } - } - } else { - nodeProperties = graph.NewProperties() + for _, additionalProperties := range propertyBags { + relationshipProperties.Merge(additionalProperties) } - newRelationship, err := s.tx.CreateRelationship(startNode, endNode, kind, nodeProperties) + s.Graph.WriteTransaction(s.testCtx, func(tx graph.Transaction) error { + relationship, err = tx.CreateRelationshipByIDs(startNode.ID, endNode.ID, kind, relationshipProperties) + return err + }) - require.Nil(s.testCtrl, err, fmt.Sprintf("error: %v", err)) - return newRelationship + return relationship } func (s *GraphTestContext) CreateAzureRelatedRoles(root *graph.Node, tenantID string, numRoles int) graph.NodeSet { diff --git a/cmd/api/src/test/integration/harnesses.go b/cmd/api/src/test/integration/harnesses.go index 4d387108c2..2b5b74a6b5 100644 --- a/cmd/api/src/test/integration/harnesses.go +++ b/cmd/api/src/test/integration/harnesses.go @@ -259,7 +259,7 @@ type OUContainedHarness struct { } func (s *OUContainedHarness) Setup(testCtx
*GraphTestContext) { - s.Domain = testCtx.NewActiveDirectoryDomain("Domain", RandomObjectID(testCtx.testCtrl), false, true) + s.Domain = testCtx.NewActiveDirectoryDomain("Domain", RandomObjectID(testCtx.testCtx), false, true) s.OUA = testCtx.NewActiveDirectoryOU("OUA", testCtx.Harness.RootADHarness.ActiveDirectoryDomainSID, false) s.OUB = testCtx.NewActiveDirectoryOU("OUB", testCtx.Harness.RootADHarness.ActiveDirectoryDomainSID, false) s.OUC = testCtx.NewActiveDirectoryOU("OUC", testCtx.Harness.RootADHarness.ActiveDirectoryDomainSID, false) @@ -318,8 +318,8 @@ type AssetGroupComboNodeHarness struct { } func (s *AssetGroupComboNodeHarness) Setup(testCtx *GraphTestContext) { - s.GroupA = testCtx.NewActiveDirectoryGroup("GroupA", RandomObjectID(testCtx.testCtrl)) - s.GroupB = testCtx.NewActiveDirectoryGroup("GroupB", RandomObjectID(testCtx.testCtrl)) + s.GroupA = testCtx.NewActiveDirectoryGroup("GroupA", RandomObjectID(testCtx.testCtx)) + s.GroupB = testCtx.NewActiveDirectoryGroup("GroupB", RandomObjectID(testCtx.testCtx)) s.GroupB.Properties.Set(common.SystemTags.String(), ad.AdminTierZero) testCtx.UpdateNode(s.GroupB) @@ -644,13 +644,13 @@ func (s *AZBaseHarness) Setup(testCtx *GraphTestContext) { numGroups = 5 numRoles = 5 ) - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Nodes = graph.NewNodeKindSet() s.Tenant = testCtx.NewAzureTenant(tenantID) - s.User = testCtx.NewAzureUser(HarnessUserName, HarnessUserName, HarnessUserDescription, RandomObjectID(testCtx.testCtrl), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) - s.Application = testCtx.NewAzureApplication(HarnessAppName, RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal(HarnessServicePrincipalName, RandomObjectID(testCtx.testCtrl), tenantID) + s.User = testCtx.NewAzureUser(HarnessUserName, HarnessUserName, HarnessUserDescription, RandomObjectID(testCtx.testCtx), HarnessUserLicenses, tenantID, 
HarnessUserMFAEnabled) + s.Application = testCtx.NewAzureApplication(HarnessAppName, RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal(HarnessServicePrincipalName, RandomObjectID(testCtx.testCtx), tenantID) s.Nodes.Add(s.Tenant, s.User, s.Application, s.ServicePrincipal) s.UserFirstDegreeGroups = graph.NewNodeSet() s.NumPaths = 1287 @@ -673,7 +673,7 @@ func (s *AZBaseHarness) Setup(testCtx *GraphTestContext) { // Create some VMs that the user has access to for vmIdx := 0; vmIdx < numVMs; vmIdx++ { - newVM := testCtx.NewAzureVM(fmt.Sprintf("vm %d", vmIdx), RandomObjectID(testCtx.testCtrl), tenantID) + newVM := testCtx.NewAzureVM(fmt.Sprintf("vm %d", vmIdx), RandomObjectID(testCtx.testCtx), tenantID) s.Nodes.Add(newVM) // Tie the vm to the tenant @@ -686,8 +686,8 @@ func (s *AZBaseHarness) Setup(testCtx *GraphTestContext) { // Create some role assignments for the user for roleIdx := 0; roleIdx < numRoles; roleIdx++ { var ( - objectID = RandomObjectID(testCtx.testCtrl) - roleTemplateID = RandomObjectID(testCtx.testCtrl) + objectID = RandomObjectID(testCtx.testCtx) + roleTemplateID = RandomObjectID(testCtx.testCtx) newRole = testCtx.NewAzureRole(fmt.Sprintf("AZRole_%s", objectID), objectID, roleTemplateID, tenantID) ) s.Nodes.Add(newRole) @@ -721,7 +721,7 @@ func (s *AZBaseHarness) CreateAzureNestedGroupChain(testCtx *GraphTestContext, t for groupIdx := 0; groupIdx < chainDepth; groupIdx++ { var ( - objectID = RandomObjectID(testCtx.testCtrl) + objectID = RandomObjectID(testCtx.testCtx) newGroup = testCtx.NewAzureGroup(fmt.Sprintf("AZGroup_%s", objectID), objectID, tenantID) ) @@ -748,11 +748,11 @@ type AZGroupMembershipHarness struct { } func (s *AZGroupMembershipHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) - s.UserA = testCtx.NewAzureUser("UserA", "UserA", "", RandomObjectID(testCtx.testCtrl), "", tenantID, false) - s.UserB = testCtx.NewAzureUser("UserB", "UserB", "", 
RandomObjectID(testCtx.testCtrl), "", tenantID, false) - s.UserC = testCtx.NewAzureUser("UserC", "UserC", "", RandomObjectID(testCtx.testCtrl), "", tenantID, false) - s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtrl), tenantID) + tenantID := RandomObjectID(testCtx.testCtx) + s.UserA = testCtx.NewAzureUser("UserA", "UserA", "", RandomObjectID(testCtx.testCtx), "", tenantID, false) + s.UserB = testCtx.NewAzureUser("UserB", "UserB", "", RandomObjectID(testCtx.testCtx), "", tenantID, false) + s.UserC = testCtx.NewAzureUser("UserC", "UserC", "", RandomObjectID(testCtx.testCtx), "", tenantID, false) + s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.UserA, s.Group, azure.MemberOf) testCtx.NewRelationship(s.UserB, s.Group, azure.MemberOf) @@ -775,19 +775,19 @@ type AZEntityPanelHarness struct { } func (s *AZEntityPanelHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) - s.Application = testCtx.NewAzureApplication("App", RandomObjectID(testCtx.testCtrl), tenantID) - s.Device = testCtx.NewAzureDevice("Device", RandomObjectID(testCtx.testCtrl), RandomObjectID(testCtx.testCtrl), tenantID) - s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.ManagementGroup = testCtx.NewAzureResourceGroup("Mgmt Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.ResourceGroup = testCtx.NewAzureResourceGroup("Resource Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.KeyVault = testCtx.NewAzureKeyVault("Key Vault", RandomObjectID(testCtx.testCtrl), tenantID) - s.Role = testCtx.NewAzureRole("Role", RandomObjectID(testCtx.testCtrl), RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) - s.Subscription = testCtx.NewAzureSubscription("Sub", RandomObjectID(testCtx.testCtrl), tenantID) + tenantID := 
RandomObjectID(testCtx.testCtx) + s.Application = testCtx.NewAzureApplication("App", RandomObjectID(testCtx.testCtx), tenantID) + s.Device = testCtx.NewAzureDevice("Device", RandomObjectID(testCtx.testCtx), RandomObjectID(testCtx.testCtx), tenantID) + s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtx), tenantID) + s.ManagementGroup = testCtx.NewAzureResourceGroup("Mgmt Group", RandomObjectID(testCtx.testCtx), tenantID) + s.ResourceGroup = testCtx.NewAzureResourceGroup("Resource Group", RandomObjectID(testCtx.testCtx), tenantID) + s.KeyVault = testCtx.NewAzureKeyVault("Key Vault", RandomObjectID(testCtx.testCtx), tenantID) + s.Role = testCtx.NewAzureRole("Role", RandomObjectID(testCtx.testCtx), RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) + s.Subscription = testCtx.NewAzureSubscription("Sub", RandomObjectID(testCtx.testCtx), tenantID) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.User = testCtx.NewAzureUser("User", "UserPrincipal", "Test User", RandomObjectID(testCtx.testCtrl), "Licenses", tenantID, false) - s.VM = testCtx.NewAzureVM("VM", RandomObjectID(testCtx.testCtrl), tenantID) + s.User = testCtx.NewAzureUser("User", "UserPrincipal", "Test User", RandomObjectID(testCtx.testCtx), "Licenses", tenantID, false) + s.VM = testCtx.NewAzureVM("VM", RandomObjectID(testCtx.testCtx), tenantID) // Application testCtx.NewRelationship(s.User, s.Application, azure.Owner) @@ -836,13 +836,13 @@ type AZMGApplicationReadWriteAllHarness struct { } func (s *AZMGApplicationReadWriteAllHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft 
Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.Application = testCtx.NewAzureApplication("App", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipalB = testCtx.NewAzureServicePrincipal("Service Principal B", RandomObjectID(testCtx.testCtrl), tenantID) + s.Application = testCtx.NewAzureApplication("App", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipalB = testCtx.NewAzureServicePrincipal("Service Principal B", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.Application, azure.Contains) @@ -868,11 +868,11 @@ type AZMGAppRoleManagementReadWriteAllHarness struct { } func (s *AZMGAppRoleManagementReadWriteAllHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.ServicePrincipal, azure.Contains) @@ -890,12 +890,12 @@ type AZMGDirectoryReadWriteAllHarness struct { } func (s *AZMGDirectoryReadWriteAllHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := 
RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) + s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.Group, azure.Contains) @@ -914,12 +914,12 @@ type AZMGGroupReadWriteAllHarness struct { } func (s *AZMGGroupReadWriteAllHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) + s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.Group, azure.Contains) @@ -938,12 +938,12 @@ type AZMGGroupMemberReadWriteAllHarness struct { } func (s 
*AZMGGroupMemberReadWriteAllHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) + s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.Group, azure.Contains) @@ -965,15 +965,15 @@ type AZMGRoleManagementReadWriteDirectoryHarness struct { } func (s *AZMGRoleManagementReadWriteDirectoryHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.Application = testCtx.NewAzureApplication("App", RandomObjectID(testCtx.testCtrl), tenantID) - s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtrl), tenantID) - s.Role = testCtx.NewAzureRole("Role", RandomObjectID(testCtx.testCtrl), RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipalB = 
testCtx.NewAzureServicePrincipal("Service Principal B", RandomObjectID(testCtx.testCtrl), tenantID) + s.Application = testCtx.NewAzureApplication("App", RandomObjectID(testCtx.testCtx), tenantID) + s.Group = testCtx.NewAzureGroup("Group", RandomObjectID(testCtx.testCtx), tenantID) + s.Role = testCtx.NewAzureRole("Role", RandomObjectID(testCtx.testCtx), RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipalB = testCtx.NewAzureServicePrincipal("Service Principal B", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.Application, azure.Contains) @@ -1007,12 +1007,12 @@ type AZMGServicePrincipalEndpointReadWriteAllHarness struct { } func (s *AZMGServicePrincipalEndpointReadWriteAllHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) + tenantID := RandomObjectID(testCtx.testCtx) s.Tenant = testCtx.NewAzureTenant(tenantID) - s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtrl), tenantID) + s.MicrosoftGraph = testCtx.NewAzureServicePrincipal("Microsoft Graph", RandomObjectID(testCtx.testCtx), tenantID) - s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtrl), tenantID) - s.ServicePrincipalB = testCtx.NewAzureServicePrincipal("Service Principal B", RandomObjectID(testCtx.testCtrl), tenantID) + s.ServicePrincipal = testCtx.NewAzureServicePrincipal("Service Principal", RandomObjectID(testCtx.testCtx), tenantID) + s.ServicePrincipalB = testCtx.NewAzureServicePrincipal("Service Principal B", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.Tenant, s.MicrosoftGraph, azure.Contains) testCtx.NewRelationship(s.Tenant, s.ServicePrincipal, azure.Contains) @@ -1037,15 +1037,15 @@ type 
AZInboundControlHarness struct { } func (s *AZInboundControlHarness) Setup(testCtx *GraphTestContext) { - tenantID := RandomObjectID(testCtx.testCtrl) - s.ControlledAZUser = testCtx.NewAzureUser("Controlled AZUser", "Controlled AZUser", "", RandomObjectID(testCtx.testCtrl), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) - s.AZAppA = testCtx.NewAzureApplication("AZAppA", RandomObjectID(testCtx.testCtrl), tenantID) - s.AZGroupA = testCtx.NewAzureGroup("AZGroupA", RandomObjectID(testCtx.testCtrl), tenantID) - s.AZGroupB = testCtx.NewAzureGroup("AZGroupB", RandomObjectID(testCtx.testCtrl), tenantID) - s.AZUserA = testCtx.NewAzureUser("AZUserA", "AZUserA", "", RandomObjectID(testCtx.testCtrl), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) - s.AZUserB = testCtx.NewAzureUser("AZUserB", "AZUserB", "", RandomObjectID(testCtx.testCtrl), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) - s.AZServicePrincipalA = testCtx.NewAzureServicePrincipal("AZServicePrincipalA", RandomObjectID(testCtx.testCtrl), tenantID) - s.AZServicePrincipalB = testCtx.NewAzureServicePrincipal("AZServicePrincipalB", RandomObjectID(testCtx.testCtrl), tenantID) + tenantID := RandomObjectID(testCtx.testCtx) + s.ControlledAZUser = testCtx.NewAzureUser("Controlled AZUser", "Controlled AZUser", "", RandomObjectID(testCtx.testCtx), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) + s.AZAppA = testCtx.NewAzureApplication("AZAppA", RandomObjectID(testCtx.testCtx), tenantID) + s.AZGroupA = testCtx.NewAzureGroup("AZGroupA", RandomObjectID(testCtx.testCtx), tenantID) + s.AZGroupB = testCtx.NewAzureGroup("AZGroupB", RandomObjectID(testCtx.testCtx), tenantID) + s.AZUserA = testCtx.NewAzureUser("AZUserA", "AZUserA", "", RandomObjectID(testCtx.testCtx), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) + s.AZUserB = testCtx.NewAzureUser("AZUserB", "AZUserB", "", RandomObjectID(testCtx.testCtx), HarnessUserLicenses, tenantID, HarnessUserMFAEnabled) + s.AZServicePrincipalA = 
testCtx.NewAzureServicePrincipal("AZServicePrincipalA", RandomObjectID(testCtx.testCtx), tenantID) + s.AZServicePrincipalB = testCtx.NewAzureServicePrincipal("AZServicePrincipalB", RandomObjectID(testCtx.testCtx), tenantID) testCtx.NewRelationship(s.AZUserA, s.AZGroupA, azure.MemberOf) testCtx.NewRelationship(s.AZServicePrincipalB, s.AZGroupB, azure.MemberOf) @@ -1662,7 +1662,6 @@ func (s *ShortcutHarness) Setup(graphTestContext *GraphTestContext) { } type RootADHarness struct { - TierZero graph.NodeSet ActiveDirectoryDomainSID string ActiveDirectoryDomain *graph.Node ActiveDirectoryRDPDomainGroup *graph.Node @@ -1670,7 +1669,6 @@ type RootADHarness struct { ActiveDirectoryUser *graph.Node ActiveDirectoryOU *graph.Node ActiveDirectoryGPO *graph.Node - ActiveDirectoryDCSyncMetaRelationship *graph.Relationship ActiveDirectoryDCSyncAtomicRelationship *graph.Relationship NumCollectedDomains int } diff --git a/packages/go/cypher/frontend/parse_test.go b/cmd/api/src/test/integration/server.go similarity index 65% rename from packages/go/cypher/frontend/parse_test.go rename to cmd/api/src/test/integration/server.go index 68b82ec5aa..12ea05df3b 100644 --- a/packages/go/cypher/frontend/parse_test.go +++ b/cmd/api/src/test/integration/server.go @@ -1,31 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 -package frontend_test - -import ( - "testing" - - "github.com/specterops/bloodhound/cypher/test" -) - -func TestParseCypher_HappyPath(t *testing.T) { - test.LoadFixture(t, test.PositiveTestCases).Run(t) -} - -func TestParseCypher_NegativeCases(t *testing.T) { - test.LoadFixture(t, test.NegativeTestCases).Run(t) -} +package integration diff --git a/cmd/api/src/test/lab/fixtures/api.go b/cmd/api/src/test/lab/fixtures/api.go index f21289adb8..523de11bf4 100644 --- a/cmd/api/src/test/lab/fixtures/api.go +++ b/cmd/api/src/test/lab/fixtures/api.go @@ -1,97 +1,90 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package fixtures import ( "context" - "errors" "fmt" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/src/bootstrap" + "github.com/specterops/bloodhound/src/config" + "github.com/specterops/bloodhound/src/daemons" + "github.com/specterops/bloodhound/src/database" + "github.com/specterops/bloodhound/src/services" "log" "net/http" "net/url" "sync" "time" - "github.com/specterops/bloodhound/src/api/v2/integration" - "github.com/specterops/bloodhound/src/server" "github.com/specterops/bloodhound/lab" ) -var BHApiFixture = NewApiFixture(integration.StartBHServer) +var BHApiFixture = NewApiFixture() -func NewApiFixture(startFn integration.APIStartFunc) *lab.Fixture[integration.APIServerContext] { +func NewApiFixture() *lab.Fixture[bool] { var ( - ctx context.Context - cancel context.CancelFunc - wg *sync.WaitGroup - serverErr error - - dependencyErrs = make([]error, 0) - fixture = lab.NewFixture(func(harness *lab.Harness) (integration.APIServerContext, error) { - ctx, cancel = context.WithCancel(context.Background()) - wg = &sync.WaitGroup{} - out := integration.APIServerContext{} - if config, ok := lab.Unpack(harness, ConfigFixture); !ok { - return out, fmt.Errorf("unable to unpack ConfigFixture") - } else if err := server.EnsureServerDirectories(config); err != nil { - return out, err - } else if pgdb, ok := lab.Unpack(harness, PostgresFixture); !ok { - return out, fmt.Errorf("unable to unpack PostgresFixture") - } else if graphdb, ok := lab.Unpack(harness, GraphDBFixture); !ok { - return out, fmt.Errorf("unable to unpack GraphDBFixture") - } else if graphcache, ok := lab.Unpack(harness, GraphCacheFixture); !ok { - return out, fmt.Errorf("unable to unpack GraphCacheFixture") - } else if apicache, ok := lab.Unpack(harness, ApiCacheFixture); !ok { - return out, fmt.Errorf("unable to unpack GraphCacheFixture") + ctx, cancel = context.WithCancel(context.Background()) + wg = &sync.WaitGroup{} 
+ serverErr error + + fixture = lab.NewFixture(func(harness *lab.Harness) (bool, error) { + if cfg, ok := lab.Unpack(harness, ConfigFixture); !ok { + return false, fmt.Errorf("unable to unpack ConfigFixture") } else { - out.Context = ctx - out.DB = pgdb - out.GraphDB = graphdb - out.Configuration = config - out.APICache = apicache - out.GraphQueryCache = graphcache // Start the server wg.Add(1) + go func() { defer wg.Done() - serverErr = startFn(out) + + initializer := bootstrap.Initializer[*database.BloodhoundDB, *graph.DatabaseSwitch]{ + Configuration: cfg, + DBConnector: services.ConnectDatabases, + Entrypoint: func(ctx context.Context, cfg config.Configuration, databaseConnections bootstrap.DatabaseConnections[*database.BloodhoundDB, *graph.DatabaseSwitch]) ([]daemons.Daemon, error) { + if err := databaseConnections.RDMS.Wipe(); err != nil { + return nil, err + } + + return services.Entrypoint(ctx, cfg, databaseConnections) + }, + } + + if err := initializer.Launch(ctx, false); err != nil { + serverErr = err + } }() - if err := waitForAPI(30*time.Second, config.RootURL.String()); err != nil { - return out, err + if err := waitForAPI(30*time.Second, cfg.RootURL.String()); err != nil { + return false, err } else { - return out, nil + return true, nil } } - }, func(harness *lab.Harness, apiServerCtx integration.APIServerContext) error { + }, func(harness *lab.Harness, started bool) error { cancel() wg.Wait() + return serverErr }) ) - dependencyErrs = append(dependencyErrs, lab.SetDependency(fixture, ConfigFixture)) - dependencyErrs = append(dependencyErrs, lab.SetDependency(fixture, PostgresFixture)) - dependencyErrs = append(dependencyErrs, lab.SetDependency(fixture, GraphDBFixture)) - dependencyErrs = append(dependencyErrs, lab.SetDependency(fixture, ApiCacheFixture)) - dependencyErrs = append(dependencyErrs, lab.SetDependency(fixture, GraphCacheFixture)) - - if err := errors.Join(dependencyErrs...); err != nil { - log.Fatalf("Errors encountered while setting up 
dependencies:\n%v\n", err) + if err := lab.SetDependency(fixture, ConfigFixture); err != nil { + log.Fatalf("BHApiFixture dependency error: %v", err) } return fixture diff --git a/cmd/api/src/test/lab/fixtures/apiclient.go b/cmd/api/src/test/lab/fixtures/apiclient.go index cfba8d3039..a75ac3a09b 100644 --- a/cmd/api/src/test/lab/fixtures/apiclient.go +++ b/cmd/api/src/test/lab/fixtures/apiclient.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package fixtures @@ -20,14 +20,14 @@ import ( "fmt" "log" + "github.com/specterops/bloodhound/lab" "github.com/specterops/bloodhound/src/api/v2/apiclient" "github.com/specterops/bloodhound/src/api/v2/integration" - "github.com/specterops/bloodhound/lab" ) var BHApiClientFixture = NewApiClientFixture(BHApiFixture) -func NewApiClientFixture(apiFixture *lab.Fixture[integration.APIServerContext]) *lab.Fixture[apiclient.Client] { +func NewApiClientFixture(apiFixture *lab.Fixture[bool]) *lab.Fixture[apiclient.Client] { fixture := lab.NewFixture(func(harness *lab.Harness) (apiclient.Client, error) { if config, ok := lab.Unpack(harness, ConfigFixture); !ok { return apiclient.Client{}, fmt.Errorf("unable to unpack ConfigFixture") diff --git a/cmd/api/src/test/lab/fixtures/graphdb.go b/cmd/api/src/test/lab/fixtures/graphdb.go index 25d9420ba9..db1120324d 100644 --- a/cmd/api/src/test/lab/fixtures/graphdb.go +++ b/cmd/api/src/test/lab/fixtures/graphdb.go @@ -17,32 +17,34 @@ package fixtures import ( + "context" "fmt" + schema "github.com/specterops/bloodhound/graphschema" "log" - "github.com/specterops/bloodhound/dawgs" - "github.com/specterops/bloodhound/dawgs/drivers/neo4j" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/lab" - "github.com/specterops/bloodhound/src/server" + "github.com/specterops/bloodhound/src/bootstrap" ) var GraphDBFixture = NewGraphDBFixture() -func NewGraphDBFixture() *lab.Fixture[graph.Database] { - fixture := lab.NewFixture(func(harness *lab.Harness) (graph.Database, error) { +func NewGraphDBFixture() *lab.Fixture[*graph.DatabaseSwitch] { + fixture := lab.NewFixture(func(harness *lab.Harness) (*graph.DatabaseSwitch, error) { if config, ok := lab.Unpack(harness, ConfigFixture); !ok { return nil, fmt.Errorf("unable to unpack ConfigFixture") - } else if graphdb, err := dawgs.Open(neo4j.DriverName, dawgs.Config{DriverCfg: config.Neo4J.Neo4jConnectionString()}); err != nil { 
- return graphdb, err - } else if err := server.MigrateGraph(config, graphdb); err != nil { - return graphdb, fmt.Errorf("failed migrating Graph database: %v", err) + } else if graphdb, err := bootstrap.ConnectGraph(context.TODO(), config); err != nil { + return nil, err + } else if err := bootstrap.MigrateGraph(context.Background(), graphdb, schema.DefaultGraphSchema()); err != nil { + return nil, fmt.Errorf("failed migrating Graph database: %v", err) } else { - return graphdb, nil + return graph.NewDatabaseSwitch(context.Background(), graphdb), nil } }, nil) + if err := lab.SetDependency(fixture, ConfigFixture); err != nil { log.Fatalln(err) } + return fixture } diff --git a/cmd/api/src/test/lab/fixtures/postgres.go b/cmd/api/src/test/lab/fixtures/postgres.go index 40a6aed05a..d7768bdce9 100644 --- a/cmd/api/src/test/lab/fixtures/postgres.go +++ b/cmd/api/src/test/lab/fixtures/postgres.go @@ -23,8 +23,8 @@ import ( "github.com/specterops/bloodhound/src/auth" "github.com/specterops/bloodhound/lab" + "github.com/specterops/bloodhound/src/bootstrap" "github.com/specterops/bloodhound/src/database" - "github.com/specterops/bloodhound/src/server" "github.com/specterops/bloodhound/src/test/integration" ) @@ -35,7 +35,7 @@ var PostgresFixture = lab.NewFixture(func(harness *lab.Harness) (*database.Blood return nil, err } else if err := integration.Prepare(database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver())); err != nil { return nil, fmt.Errorf("failed ensuring database: %v", err) - } else if err := server.MigrateDB(config, database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver())); err != nil { + } else if err := bootstrap.MigrateDB(config, database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver())); err != nil { return nil, fmt.Errorf("failed migrating database: %v", err) } else { return database.NewBloodhoundDB(pgdb, auth.NewIdentityResolver()), nil diff --git a/cmd/api/src/test/require.go b/cmd/api/src/test/require.go new file mode 100644 index 
0000000000..cd8d7fa810 --- /dev/null +++ b/cmd/api/src/test/require.go @@ -0,0 +1,42 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package test + +import ( + "fmt" + "github.com/stretchr/testify/require" +) + +func RequireNilErr(t Controller, err error) { + errMsg := "" + + if err != nil { + errMsg = err.Error() + } + + require.Nilf(t, err, "Error must be nil but found %T: %s", err, errMsg) +} + +func RequireNilErrf(t Controller, err error, format string, parameters ...any) { + errMsg := "" + + if err != nil { + errMsg = err.Error() + } + + require.Nilf(t, err, fmt.Sprintf(format, parameters...)+": error must be nil but found %T: %s", err, errMsg) +} diff --git a/cmd/api/src/utils/reflect.go b/cmd/api/src/utils/reflect.go index e6fb65726f..db0c1ba561 100644 --- a/cmd/api/src/utils/reflect.go +++ b/cmd/api/src/utils/reflect.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package utils diff --git a/go.work b/go.work index 34308b2f48..22dfbb4ede 100644 --- a/go.work +++ b/go.work @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -go 1.20 +go 1.21 use ( ./cmd/api/src diff --git a/license_check.py b/license_check.py index 2ba99ba1e7..3177a9c64a 100755 --- a/license_check.py +++ b/license_check.py @@ -312,6 +312,7 @@ ".crt", ".key", ".example", + ".svg", ] # Any file listed below is included regardless of exclusions. @@ -360,6 +361,15 @@ def generate_license_header(comment_prefix: str) -> str: ".toml": generate_license_header("#"), } +# Below is a list of valid file headers that the license must be placed after +FILE_HEADER_PREFIXES = [ + # POSIX exec header + "#!", + + # XML header + " bool: matching_header = False @@ -386,6 +396,13 @@ def content_has_header(path: str, content_lines: List[str], header: str) -> bool return False +def _is_file_header(line: str) -> bool: + for header in FILE_HEADER_PREFIXES: + if line.startswith(header): + return True + return False + + def insert_license_header(path: str, header: str) -> None: with open(path, "r") as fin: content = fin.read() @@ -398,7 +415,7 @@ def insert_license_header(path: str, header: str) -> None: return # Try to find a script exec header to advance the line offset - line_offset = 1 if len(content_lines) > 0 and content_lines[0].startswith("#!") else 0 + line_offset = 1 if len(content_lines) > 0 and _is_file_header(content_lines[0]) else 0 for line in content_lines[line_offset:]: # Make sure to skip leading newlines since we'll add our own diff --git a/packages/cue/bh/ad/ad.cue b/packages/cue/bh/ad/ad.cue index c6922385a3..aace277a96 100644 --- a/packages/cue/bh/ad/ad.cue +++ b/packages/cue/bh/ad/ad.cue @@ -1134,5 +1134,6 @@ PathfindingRelationships: [ ADCSESC4, ADCSESC5, ADCSESC6, - ADCSESC7 + ADCSESC7, + DCFor ] diff 
--git a/packages/go/analysis/ad/ad.go b/packages/go/analysis/ad/ad.go index 76111c0015..1c0d029a56 100644 --- a/packages/go/analysis/ad/ad.go +++ b/packages/go/analysis/ad/ad.go @@ -76,33 +76,39 @@ func TierZeroWellKnownSIDSuffixes() []string { AdministratorsGroupSIDSuffix, } } -func FetchWellKnownTierZeroEntities(tx graph.Transaction, domainSID string) (graph.NodeSet, error) { + +func FetchWellKnownTierZeroEntities(ctx context.Context, db graph.Database, domainSID string) (graph.NodeSet, error) { + defer log.Measure(log.LevelInfo, "FetchWellKnownTierZeroEntities")() + nodes := graph.NewNodeSet() - for _, wellKnownSIDSuffix := range TierZeroWellKnownSIDSuffixes() { - if err := tx.Nodes().Filterf(func() graph.Criteria { - return query.And( - // Make sure we have the Group or User label. This should cover the case for URA as well as filter out all the other localgroups - query.KindIn(query.Node(), ad.Group, ad.User), - query.StringEndsWith(query.NodeProperty(common.ObjectID.String()), wellKnownSIDSuffix), - query.Equals(query.NodeProperty(ad.DomainSID.String()), domainSID), - ) - }).Fetch(func(cursor graph.Cursor[*graph.Node]) error { - for node := range cursor.Chan() { - nodes.Add(node) - } + return nodes, db.ReadTransaction(ctx, func(tx graph.Transaction) error { + for _, wellKnownSIDSuffix := range TierZeroWellKnownSIDSuffixes() { + if err := tx.Nodes().Filterf(func() graph.Criteria { + return query.And( + // Make sure we have the Group or User label. 
This should cover the case for URA as well as filter out all the other localgroups + query.KindIn(query.Node(), ad.Group, ad.User), + query.StringEndsWith(query.NodeProperty(common.ObjectID.String()), wellKnownSIDSuffix), + query.Equals(query.NodeProperty(ad.DomainSID.String()), domainSID), + ) + }).Fetch(func(cursor graph.Cursor[*graph.Node]) error { + for node := range cursor.Chan() { + nodes.Add(node) + } - return cursor.Error() - }); err != nil { - return nil, err + return cursor.Error() + }); err != nil { + return err + } } - } - return nodes, nil + return nil + }) } func FixWellKnownNodeTypes(ctx context.Context, db graph.Database) error { defer log.Measure(log.LevelInfo, "Fix well known node types")() + groupSuffixes := []string{EnterpriseKeyAdminsGroupSIDSuffix, KeyAdminsGroupSIDSuffix, EnterpriseDomainControllersGroupSIDSuffix, @@ -689,7 +695,7 @@ func GetADCSESC3EdgeComposition(ctx context.Context, db graph.Database, edge *gr //Find all cert templates we have EnrollOnBehalfOf from our first group of templates to prefilter again if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { - if p, err := ops.FetchPathSet(tx, tx.Relationships().Filter( + if p, err := ops.FetchPathSet(tx.Relationships().Filter( query.And( query.InIDs(query.StartID(), cardinality.DuplexToGraphIDs(path1CertTemplates)...), query.KindIn(query.Relationship(), ad.EnrollOnBehalfOf), @@ -786,7 +792,7 @@ func getDelegatedEnrollmentAgentPath(ctx context.Context, startNode, certTemplat var pathSet graph.PathSet return pathSet, db.ReadTransaction(ctx, func(tx graph.Transaction) error { - if paths, err := ops.FetchPathSet(tx, tx.Relationships().Filter(query.And( + if paths, err := ops.FetchPathSet(tx.Relationships().Filter(query.And( query.InIDs(query.StartID(), startNode.ID), query.InIDs(query.EndID(), certTemplate2.ID), query.KindIn(query.Relationship(), ad.DelegatedEnrollmentAgent), @@ -941,7 +947,7 @@ func getGoldenCertEdgeComposition(tx graph.Transaction, edge 
*graph.Relationship return finalPaths, err } else { //Find hosted enterprise CA - if ecaPaths, err := ops.FetchPathSet(tx, tx.Relationships().Filter(query.And( + if ecaPaths, err := ops.FetchPathSet(tx.Relationships().Filter(query.And( query.Equals(query.StartID(), startNode.ID), query.KindIn(query.End(), ad.EnterpriseCA), query.KindIn(query.Relationship(), ad.HostsCAService), diff --git a/packages/go/analysis/ad/filters.go b/packages/go/analysis/ad/filters.go index 7465128e72..667c8c87c2 100644 --- a/packages/go/analysis/ad/filters.go +++ b/packages/go/analysis/ad/filters.go @@ -163,6 +163,9 @@ func SelectComputersCandidateFilter(node *graph.Node) bool { func SelectGPOTierZeroCandidateFilter(node *graph.Node) bool { if tags, err := node.Properties.Get(common.SystemTags.String()).String(); err != nil { return false + } else if node.Kinds.ContainsOneOf(ad.Group) { + // GPOs don’t apply to groups. + return false } else { return strings.Contains(tags, ad.AdminTierZero) } diff --git a/packages/go/analysis/ad/filters_test.go b/packages/go/analysis/ad/filters_test.go new file mode 100644 index 0000000000..340a7d634b --- /dev/null +++ b/packages/go/analysis/ad/filters_test.go @@ -0,0 +1,38 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package ad_test + +import ( + ad2 "github.com/specterops/bloodhound/analysis/ad" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/graphschema/ad" + "github.com/specterops/bloodhound/graphschema/common" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestSelectGPOTierZeroCandidateFilter(t *testing.T) { + var ( + computer = graph.NewNode(0, graph.NewProperties(), ad.Computer) + group = graph.NewNode(1, graph.NewProperties().Set(common.SystemTags.String(), ad.AdminTierZero), ad.Group) + user = graph.NewNode(2, graph.NewProperties().Set(common.SystemTags.String(), ad.AdminTierZero), ad.User) + ) + + assert.False(t, ad2.SelectGPOTierZeroCandidateFilter(computer)) + assert.False(t, ad2.SelectGPOTierZeroCandidateFilter(group)) + assert.True(t, ad2.SelectGPOTierZeroCandidateFilter(user)) +} diff --git a/packages/go/analysis/ad/queries.go b/packages/go/analysis/ad/queries.go index 3f14b21000..1293ea56f8 100644 --- a/packages/go/analysis/ad/queries.go +++ b/packages/go/analysis/ad/queries.go @@ -34,52 +34,65 @@ import ( "github.com/specterops/bloodhound/log" ) -func FetchGraphDBTierZeroTaggedAssets(tx graph.Transaction, domainSID string) (graph.NodeSet, error) { - if nodeSet, err := ops.FetchNodeSet(tx.Nodes().Filterf(func() graph.Criteria { - return query.And( - query.Kind(query.Node(), ad.Entity), - query.Equals(query.NodeProperty(ad.DomainSID.String()), domainSID), - query.StringContains(query.NodeProperty(common.SystemTags.String()), ad.AdminTierZero), - ) - })); err != nil { - return nil, err - } else { - return nodeSet, nil - } +func FetchGraphDBTierZeroTaggedAssets(ctx context.Context, db graph.Database, domainSID string) (graph.NodeSet, error) { + defer log.Measure(log.LevelInfo, "FetchGraphDBTierZeroTaggedAssets")() + + var ( + nodes graph.NodeSet + err error + ) + + return nodes, db.ReadTransaction(ctx, func(tx graph.Transaction) error { + nodes, err =
ops.FetchNodeSet(tx.Nodes().Filterf(func() graph.Criteria { + return query.And( + query.Kind(query.Node(), ad.Entity), + query.Equals(query.NodeProperty(ad.DomainSID.String()), domainSID), + query.StringContains(query.NodeProperty(common.SystemTags.String()), ad.AdminTierZero), + ) + })) + + return err + }) + } -func FetchAllEnforcedGPOs(tx graph.Transaction, targets graph.NodeSet) (graph.NodeSet, error) { +func FetchAllEnforcedGPOs(ctx context.Context, db graph.Database, targets graph.NodeSet) (graph.NodeSet, error) { + defer log.Measure(log.LevelInfo, "FetchAllEnforcedGPOs")() + enforcedGPOs := graph.NewNodeSet() - for _, attackPathRoot := range targets { - if enforced, err := FetchEnforcedGPOs(tx, attackPathRoot, 0, 0); err != nil { - return nil, err - } else { - enforcedGPOs.AddSet(enforced) + return enforcedGPOs, db.ReadTransaction(ctx, func(tx graph.Transaction) error { + for _, attackPathRoot := range targets { + if enforced, err := FetchEnforcedGPOs(tx, attackPathRoot, 0, 0); err != nil { + return err + } else { + enforcedGPOs.AddSet(enforced) + } } - } - return enforcedGPOs, nil + return nil + }) } -func FetchAllDomains(ctx context.Context, db graph.Database) (graph.NodeSet, error) { +func FetchAllDomains(ctx context.Context, db graph.Database) ([]*graph.Node, error) { var ( - nodes graph.NodeSet + nodes []*graph.Node err error ) return nodes, db.ReadTransaction(ctx, func(tx graph.Transaction) error { - nodes, err = ops.FetchNodeSet(tx.Nodes().Filterf(func() graph.Criteria { + nodes, err = ops.FetchNodes(tx.Nodes().Filterf(func() graph.Criteria { return query.Kind(query.Node(), ad.Domain) - })) + }).OrderBy( + query.Order(query.NodeProperty(common.Name.String()), query.Descending()), + )) return err }) } -func FetchActiveDirectoryTierZeroRoots(tx graph.Transaction, domain *graph.Node) (graph.NodeSet, error) { - log.Infof("Fetching tier zero nodes for domain %d", domain.ID) - defer log.Measure(log.LevelInfo, "Finished fetching tier zero nodes for domain 
%d", domain.ID)() +func FetchActiveDirectoryTierZeroRoots(ctx context.Context, db graph.Database, domain *graph.Node) (graph.NodeSet, error) { + defer log.LogAndMeasure(log.LevelInfo, "FetchActiveDirectoryTierZeroRoots")() if domainSID, err := domain.Properties.Get(common.ObjectID.String()).String(); err != nil { return nil, err @@ -90,28 +103,28 @@ func FetchActiveDirectoryTierZeroRoots(tx graph.Transaction, domain *graph.Node) attackPathRoots.Add(domain) // Pull in custom tier zero tagged assets - if customTierZeroNodes, err := FetchGraphDBTierZeroTaggedAssets(tx, domainSID); err != nil { + if customTierZeroNodes, err := FetchGraphDBTierZeroTaggedAssets(ctx, db, domainSID); err != nil { return nil, err } else { attackPathRoots.AddSet(customTierZeroNodes) } // Pull in well known tier zero nodes by SID suffix - if wellKnownTierZeroNodes, err := FetchWellKnownTierZeroEntities(tx, domainSID); err != nil { + if wellKnownTierZeroNodes, err := FetchWellKnownTierZeroEntities(ctx, db, domainSID); err != nil { return nil, err } else { attackPathRoots.AddSet(wellKnownTierZeroNodes) } // Pull in all group members of attack path roots - if allGroupMembers, err := FetchAllGroupMembers(tx, attackPathRoots); err != nil { + if allGroupMembers, err := FetchAllGroupMembers(ctx, db, attackPathRoots); err != nil { return nil, err } else { attackPathRoots.AddSet(allGroupMembers) } // Add all enforced GPO nodes to the attack path roots - if enforcedGPOs, err := FetchAllEnforcedGPOs(tx, attackPathRoots); err != nil { + if enforcedGPOs, err := FetchAllEnforcedGPOs(ctx, db, attackPathRoots); err != nil { return nil, err } else { attackPathRoots.AddSet(enforcedGPOs) @@ -622,7 +635,7 @@ func FetchForeignGPOControllerPaths(tx graph.Transaction, node *graph.Node) (gra })); err != nil { return nil, err } else { - if directControllers, err := ops.FetchPathSet(tx, tx.Relationships().Filterf(func() graph.Criteria { + if directControllers, err := ops.FetchPathSet(tx.Relationships().Filterf(func() 
graph.Criteria { return query.And( query.InIDs(query.EndID(), gpoIDs...), query.KindIn(query.Relationship(), ad.ACLRelationships()...), @@ -716,7 +729,7 @@ func FetchForeignAdminPaths(tx graph.Transaction, node *graph.Node) (graph.PathS if domainSid, err := node.Properties.Get(ad.DomainSID.String()).String(); err != nil { return nil, err } else { - if directAdmins, err := ops.FetchPathSet(tx, tx.Relationships().Filterf(func() graph.Criteria { + if directAdmins, err := ops.FetchPathSet(tx.Relationships().Filterf(func() graph.Criteria { return query.And( query.Kind(query.End(), ad.Computer), query.Kind(query.Relationship(), ad.AdminTo), @@ -1197,15 +1210,32 @@ func FetchGroupMemberPaths(tx graph.Transaction, node *graph.Node) (graph.PathSe }) } -func FetchGroupMembers(tx graph.Transaction, root *graph.Node, skip, limit int) (graph.NodeSet, error) { - return ops.AcyclicTraverseNodes(tx, ops.TraversalPlan{ - Root: root, - Direction: graph.DirectionInbound, - Skip: skip, - Limit: limit, - BranchQuery: FilterGroupMembership, - }, func(node *graph.Node) bool { - return node.ID != root.ID +func FetchGroupMembers(ctx context.Context, db graph.Database, root *graph.Node, skip, limit int) (graph.NodeSet, error) { + collector := traversal.NewNodeCollector() + + if err := traversal.New(db, analysis.MaximumDatabaseParallelWorkers).BreadthFirst(ctx, traversal.Plan{ + Root: root, + Driver: traversal.LightweightDriver( + graph.DirectionInbound, + graphcache.New(), + query.Kind(query.Relationship(), ad.MemberOf), + traversal.AcyclicNodeFilter( + traversal.FilteredSkipLimit( + func(next *graph.PathSegment) (bool, bool) { + return true, next.Node.Kinds.ContainsOneOf(ad.Group) + }, + collector.Collect, + skip, + limit, + ), + ), + ), + }); err != nil { + return nil, err + } + + return collector.Nodes, db.ReadTransaction(ctx, func(tx graph.Transaction) error { + return ops.FetchAllNodeProperties(tx, collector.Nodes) }) } @@ -1352,12 +1382,16 @@ func FetchUserSessionCompleteness(tx 
graph.Transaction, domainSIDs ...string) (f } } -func FetchAllGroupMembers(tx graph.Transaction, targets graph.NodeSet) (graph.NodeSet, error) { +func FetchAllGroupMembers(ctx context.Context, db graph.Database, targets graph.NodeSet) (graph.NodeSet, error) { + defer log.Measure(log.LevelInfo, "FetchAllGroupMembers")() + + log.Infof("Fetching group members for %d AD nodes", len(targets)) + allGroupMembers := graph.NewNodeSet() for _, target := range targets { if target.Kinds.ContainsOneOf(ad.Group) { - if groupMembers, err := FetchGroupMembers(tx, target, 0, 0); err != nil { + if groupMembers, err := FetchGroupMembers(ctx, db, target, 0, 0); err != nil { return nil, err } else { allGroupMembers.AddSet(groupMembers) @@ -1365,6 +1399,7 @@ func FetchAllGroupMembers(tx graph.Transaction, targets graph.NodeSet) (graph.No } } + log.Infof("Collected %d group members", len(allGroupMembers)) return allGroupMembers, nil } diff --git a/packages/go/analysis/analysis.go b/packages/go/analysis/analysis.go index d0fa6bf16c..396e0273d1 100644 --- a/packages/go/analysis/analysis.go +++ b/packages/go/analysis/analysis.go @@ -44,9 +44,7 @@ var ( func AllTaggedNodesFilter(additionalFilter graph.Criteria) graph.Criteria { var ( filters = []graph.Criteria{ - query.Not( - query.Equals(query.NodeProperty(common.SystemTags.String()), ""), - ), + query.IsNotNull(query.NodeProperty(common.SystemTags.String())), } ) diff --git a/packages/go/analysis/azure/post.go b/packages/go/analysis/azure/post.go index fbe9a167b9..79857da96a 100644 --- a/packages/go/analysis/azure/post.go +++ b/packages/go/analysis/azure/post.go @@ -263,6 +263,8 @@ func (s RoleAssignments) NodeHasRole(id graph.ID, roleTemplateIDs ...string) boo // TenantRoles returns the NodeSet of roles for a given tenant that match one of the given role template IDs. If no role template ID is provided, then all of the tenant role nodes are returned in the NodeSet. 
func TenantRoles(tx graph.Transaction, tenant *graph.Node, roleTemplateIDs ...string) (graph.NodeSet, error) { + defer log.LogAndMeasure(log.LevelInfo, "Tenant %d TenantRoles", tenant.ID)() + if !IsTenantNode(tenant) { return nil, fmt.Errorf("cannot fetch tenant roles - node %d must be of kind %s", tenant.ID, azure.Tenant) } @@ -339,6 +341,8 @@ func roleMembers(tx graph.Transaction, tenantRoles graph.NodeSet, additionalRela // RoleMembersWithGrants returns the NodeSet of members for a given set of roles, including those members who may be able to grant themselves one of the given roles // NOTE: The current implementation also includes the role nodes in the returned set. It may be worth considering removing those nodes from the set if doing so doesn't break tier zero/high value assignment func RoleMembersWithGrants(tx graph.Transaction, tenant *graph.Node, roleTemplateIDs ...string) (graph.NodeSet, error) { + defer log.LogAndMeasure(log.LevelInfo, "Tenant %d RoleMembersWithGrants", tenant.ID)() + if tenantRoles, err := TenantRoles(tx, tenant, roleTemplateIDs...); err != nil { return nil, err } else { @@ -569,6 +573,7 @@ func createAZMGAppRoleAssignmentReadWriteAllEdges(ctx context.Context, db graph. 
} } } + return nil }) } diff --git a/packages/go/analysis/azure/queries.go b/packages/go/analysis/azure/queries.go index 9d4ee15034..016b86db9c 100644 --- a/packages/go/analysis/azure/queries.go +++ b/packages/go/analysis/azure/queries.go @@ -54,6 +54,8 @@ func GetCollectedTenants(ctx context.Context, db graph.Database) (graph.NodeSet, } func FetchGraphDBTierZeroTaggedAssets(tx graph.Transaction, tenant *graph.Node) (graph.NodeSet, error) { + defer log.LogAndMeasure(log.LevelInfo, "Tenant %d FetchGraphDBTierZeroTaggedAssets", tenant.ID)() + if tenantObjectID, err := tenant.Properties.Get(common.ObjectID.String()).String(); err != nil { log.Errorf("Tenant node %d does not have a valid %s property: %v", tenant.ID, common.ObjectID, err) return nil, err @@ -73,8 +75,7 @@ func FetchGraphDBTierZeroTaggedAssets(tx graph.Transaction, tenant *graph.Node) } func FetchAzureAttackPathRoots(tx graph.Transaction, tenant *graph.Node) (graph.NodeSet, error) { - log.Infof("Fetching tier zero nodes for tenant %d", tenant.ID) - defer log.Measure(log.LevelInfo, "Finished fetching tier zero nodes for tenant %d", tenant.ID)() + defer log.LogAndMeasure(log.LevelDebug, "Tenant %d FetchAzureAttackPathRoots", tenant.ID)() attackPathRoots := graph.NewNodeKindSet() @@ -232,8 +233,6 @@ func FetchAzureAttackPathRoots(tx graph.Transaction, tenant *graph.Node) (graph. 
} } - log.Infof("Collapsed an additional %d nodes into tier zero for non-descent relationships", inboundNodes.Len()) - tierZeroNodes.AddSet(inboundNodes) return tierZeroNodes, nil } diff --git a/packages/go/analysis/go.mod b/packages/go/analysis/go.mod index c6ab3c3ee7..07f6010adc 100644 --- a/packages/go/analysis/go.mod +++ b/packages/go/analysis/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/analysis -go 1.20 +go 1.21 require ( github.com/RoaringBitmap/roaring v1.3.0 diff --git a/packages/go/cache/cache.go b/packages/go/cache/cache.go index 0f4c984d22..0459e766af 100644 --- a/packages/go/cache/cache.go +++ b/packages/go/cache/cache.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package cache diff --git a/packages/go/cache/go.mod b/packages/go/cache/go.mod index 59f76a0bd2..18f5b42f51 100644 --- a/packages/go/cache/go.mod +++ b/packages/go/cache/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/cache -go 1.20 +go 1.21 require ( github.com/hashicorp/golang-lru v0.6.0 diff --git a/packages/go/conftool/go.mod b/packages/go/conftool/go.mod index 13e74c7221..a881849047 100644 --- a/packages/go/conftool/go.mod +++ b/packages/go/conftool/go.mod @@ -16,4 +16,4 @@ module github.com/specterops/bloodhound/conftool -go 1.20 +go 1.21 diff --git a/packages/go/crypto/go.mod b/packages/go/crypto/go.mod index 77e56dd8b1..2f9f474bba 100644 --- a/packages/go/crypto/go.mod +++ b/packages/go/crypto/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/crypto -go 1.20 +go 1.21 require ( go.uber.org/mock v0.2.0 diff --git a/packages/go/cypher/analyzer/analyzer.go b/packages/go/cypher/analyzer/analyzer.go index 9da373b29c..3ce2c29883 100644 --- a/packages/go/cypher/analyzer/analyzer.go +++ b/packages/go/cypher/analyzer/analyzer.go @@ -1,48 +1,66 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package analyzer import ( + "errors" + "fmt" "github.com/specterops/bloodhound/cypher/model" "github.com/specterops/bloodhound/dawgs/graph" ) type Analyzer struct { - handlers []func(node any) + handlers []func(stack *model.WalkStack, node model.Expression) error } -func (s *Analyzer) Analyze(query *model.RegularQuery) error { - return model.Walk( - query, func(parent, node any) error { - for _, handler := range s.handlers { - handler(node) - } +func (s *Analyzer) walkFunc(stack *model.WalkStack, expression model.Expression) error { + var errs []error - return nil - }, - nil, - ) + for _, handler := range s.handlers { + if err := handler(stack, expression); err != nil { + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +func (s *Analyzer) Analyze(query any, extensions ...model.CollectorFunc) error { + return model.Walk(query, model.NewVisitor(s.walkFunc, nil), extensions...) +} + +func Analyze(query any, registrationFunc func(analyzerInst *Analyzer), extensions ...model.CollectorFunc) error { + analyzer := &Analyzer{} + registrationFunc(analyzer) + + return analyzer.Analyze(query, extensions...) 
} -func WithVisitor[T any](analyzer *Analyzer, visitorFunc func(node T)) { - analyzer.handlers = append(analyzer.handlers, func(node any) { +type TypedVisitor[T model.Expression] func(stack *model.WalkStack, node T) error + +func WithVisitor[T model.Expression](analyzer *Analyzer, visitorFunc TypedVisitor[T]) { + analyzer.handlers = append(analyzer.handlers, func(walkStack *model.WalkStack, node model.Expression) error { if typedNode, typeOK := node.(T); typeOK { - visitorFunc(typedNode) + if err := visitorFunc(walkStack, typedNode); err != nil { + return err + } } + + return nil }) } @@ -62,7 +80,7 @@ type ComplexityMeasure struct { nodeLookupKinds map[string]graph.Kinds } -func (s *ComplexityMeasure) onFunctionInvocation(node *model.FunctionInvocation) { +func (s *ComplexityMeasure) onFunctionInvocation(_ *model.WalkStack, node *model.FunctionInvocation) error { switch node.Name { case "collect": // Collect will force an eager aggregation @@ -72,29 +90,35 @@ func (s *ComplexityMeasure) onFunctionInvocation(node *model.FunctionInvocation) // Calling for a relationship's type is highly likely to be inefficient and should add weight s.Weight += Weight2 } + + return nil } -func (s *ComplexityMeasure) onQuantifier(node *model.Quantifier) { +func (s *ComplexityMeasure) onQuantifier(_ *model.WalkStack, _ *model.Quantifier) error { // Quantifier expressions may increase the size of an inline projection to apply its contained filter and should // be weighted s.Weight += Weight1 + return nil } -func (s *ComplexityMeasure) onFilterExpression(node *model.FilterExpression) { +func (s *ComplexityMeasure) onFilterExpression(_ *model.WalkStack, _ *model.FilterExpression) error { // Filter expressions convert directly into a filter in the query plan which may or may not take advantage // of indexes and should be weighted accordingly s.Weight += Weight1 + return nil } -func (s *ComplexityMeasure) onKindMatcher(node *model.KindMatcher) { +func (s *ComplexityMeasure) onKindMatcher(_ 
*model.WalkStack, node *model.KindMatcher) error { switch typedReference := node.Reference.(type) { case *model.Variable: // This kind matcher narrows a node reference's kind and will result in an indexed lookup s.nodeLookupKinds[typedReference.Symbol] = s.nodeLookupKinds[typedReference.Symbol].Add(node.Kinds...) } + + return nil } -func (s *ComplexityMeasure) onPatternPart(node *model.PatternPart) { +func (s *ComplexityMeasure) onPatternPart(_ *model.WalkStack, node *model.PatternPart) error { // All pattern parts incur a compounding weight s.numPatterns += 1 s.Weight += s.numPatterns @@ -109,14 +133,17 @@ func (s *ComplexityMeasure) onPatternPart(node *model.PatternPart) { // Rendering all shortest paths could result in a large search s.Weight += Weight2 } + + return nil } -func (s *ComplexityMeasure) onSortItem(node *model.SortItem) { +func (s *ComplexityMeasure) onSortItem(_ *model.WalkStack, _ *model.SortItem) error { // Sorting incurs a weight since it will change how the projection is materialized s.Weight += Weight1 + return nil } -func (s *ComplexityMeasure) onProjection(node *model.Projection) { +func (s *ComplexityMeasure) onProjection(_ *model.WalkStack, node *model.Projection) error { // We want to capture the cost of additional inline projections so ignore the first projection s.Weight += s.numProjections s.numProjections += 1 @@ -125,25 +152,31 @@ func (s *ComplexityMeasure) onProjection(node *model.Projection) { // Distinct incurs a weight since it will change how the projection is materialized s.Weight += Weight1 } + + return nil } -func (s *ComplexityMeasure) onPartialComparison(node *model.PartialComparison) { +func (s *ComplexityMeasure) onPartialComparison(_ *model.WalkStack, node *model.PartialComparison) error { switch node.Operator { case model.OperatorRegexMatch: // Regular expression matching incurs a weight since it can be far more involved than any of the other // string operators s.Weight += Weight1 } + + return nil } -func (s 
*ComplexityMeasure) onNodePattern(node *model.NodePattern) { - if node.Binding == "" { +func (s *ComplexityMeasure) onNodePattern(_ *model.WalkStack, node *model.NodePattern) error { + if node.Binding == nil { if len(node.Kinds) == 0 { // Unlabeled, unbound nodes will incur a lookup of all nodes in the graph s.Weight += Weight2 } + } else if nodePatternBinding, typeOK := node.Binding.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable for node pattern binding but got: %T", node.Binding) } else { - nodeLookupKinds, hasBinding := s.nodeLookupKinds[node.Binding] + nodeLookupKinds, hasBinding := s.nodeLookupKinds[nodePatternBinding.Symbol] if !hasBinding { nodeLookupKinds = node.Kinds @@ -152,11 +185,13 @@ func (s *ComplexityMeasure) onNodePattern(node *model.NodePattern) { } // Track this node pattern to see if any subsequent expressions will narrow its kind matchers - s.nodeLookupKinds[node.Binding] = nodeLookupKinds + s.nodeLookupKinds[nodePatternBinding.Symbol] = nodeLookupKinds } + + return nil } -func (s *ComplexityMeasure) onRelationshipPattern(node *model.RelationshipPattern) { +func (s *ComplexityMeasure) onRelationshipPattern(_ *model.WalkStack, node *model.RelationshipPattern) error { numKindMatchers := len(node.Kinds) // All relationship lookups incur a weight @@ -191,6 +226,8 @@ func (s *ComplexityMeasure) onRelationshipPattern(node *model.RelationshipPatter s.Weight += Weight1 } } + + return nil } func (s *ComplexityMeasure) onExit() { @@ -210,16 +247,16 @@ func QueryComplexity(query *model.RegularQuery) (*ComplexityMeasure, error) { } ) - WithVisitor[*model.PatternPart](analyzer, measure.onPatternPart) - WithVisitor[*model.NodePattern](analyzer, measure.onNodePattern) - WithVisitor[*model.Projection](analyzer, measure.onProjection) - WithVisitor[*model.RelationshipPattern](analyzer, measure.onRelationshipPattern) - WithVisitor[*model.FunctionInvocation](analyzer, measure.onFunctionInvocation) - WithVisitor[*model.KindMatcher](analyzer, 
measure.onKindMatcher) - WithVisitor[*model.Quantifier](analyzer, measure.onQuantifier) - WithVisitor[*model.FilterExpression](analyzer, measure.onFilterExpression) - WithVisitor[*model.SortItem](analyzer, measure.onSortItem) - WithVisitor[*model.PartialComparison](analyzer, measure.onPartialComparison) + WithVisitor(analyzer, measure.onPatternPart) + WithVisitor(analyzer, measure.onNodePattern) + WithVisitor(analyzer, measure.onProjection) + WithVisitor(analyzer, measure.onRelationshipPattern) + WithVisitor(analyzer, measure.onFunctionInvocation) + WithVisitor(analyzer, measure.onKindMatcher) + WithVisitor(analyzer, measure.onQuantifier) + WithVisitor(analyzer, measure.onFilterExpression) + WithVisitor(analyzer, measure.onSortItem) + WithVisitor(analyzer, measure.onPartialComparison) if err := analyzer.Analyze(query); err != nil { return nil, err diff --git a/packages/go/cypher/analyzer/analyzer_test.go b/packages/go/cypher/analyzer/analyzer_test.go index 7f141d67a5..166431df2d 100644 --- a/packages/go/cypher/analyzer/analyzer_test.go +++ b/packages/go/cypher/analyzer/analyzer_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package analyzer_test @@ -19,10 +19,10 @@ package analyzer_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/cypher/analyzer" "github.com/specterops/bloodhound/cypher/frontend" "github.com/specterops/bloodhound/cypher/test" + "github.com/stretchr/testify/require" ) func TestQueryComplexity(t *testing.T) { @@ -30,7 +30,7 @@ func TestQueryComplexity(t *testing.T) { for _, testCase := range test.LoadFixture(t, test.PositiveTestCases).RunnableCases() { t.Run(testCase.Name, func(t *testing.T) { // Only bother with the string match tests - if testCase.Type == test.TestTypeStringMatch { + if testCase.Type == test.TypeStringMatch { var ( details = test.UnmarshallTestCaseDetails[test.StringMatchTest](t, testCase) parseContext = frontend.NewContext() diff --git a/packages/go/cypher/frontend/format.go b/packages/go/cypher/backend/cypher/format.go similarity index 71% rename from packages/go/cypher/frontend/format.go rename to packages/go/cypher/backend/cypher/format.go index 2f71484614..21438c4b51 100644 --- a/packages/go/cypher/frontend/format.go +++ b/packages/go/cypher/backend/cypher/format.go @@ -14,7 +14,7 @@ // // SPDX-License-Identifier: Apache-2.0 -package frontend +package cypher import ( "fmt" @@ -44,28 +44,25 @@ func writeJoinedKinds(output io.Writer, delimiter string, kinds graph.Kinds) err return nil } -type Emitter interface { - Write(query *model.RegularQuery, writer io.Writer) error - WriteExpression(output io.Writer, expression model.Expression) error -} - -type CypherEmitter struct { +type Emitter struct { StripLiterals bool } func NewCypherEmitter(stripLiterals bool) Emitter { - return CypherEmitter{ + return Emitter{ StripLiterals: stripLiterals, } } -func (s CypherEmitter) formatNodePattern(output io.Writer, nodePattern *model.NodePattern) error { +func (s Emitter) formatNodePattern(output io.Writer, nodePattern *model.NodePattern) error { if _, err := 
io.WriteString(output, "("); err != nil { return err } - if _, err := io.WriteString(output, nodePattern.Binding); err != nil { - return err + if nodePattern.Binding != nil { + if err := s.WriteExpression(output, nodePattern.Binding); err != nil { + return err + } } if len(nodePattern.Kinds) > 0 { @@ -95,7 +92,7 @@ func (s CypherEmitter) formatNodePattern(output io.Writer, nodePattern *model.No return nil } -func (s CypherEmitter) formatRelationshipPattern(output io.Writer, relationshipPattern *model.RelationshipPattern) error { +func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern *model.RelationshipPattern) error { switch relationshipPattern.Direction { case graph.DirectionOutbound: if _, err := io.WriteString(output, "-["); err != nil { @@ -111,8 +108,10 @@ func (s CypherEmitter) formatRelationshipPattern(output io.Writer, relationshipP } } - if _, err := io.WriteString(output, relationshipPattern.Binding); err != nil { - return err + if relationshipPattern.Binding != nil { + if err := s.WriteExpression(output, relationshipPattern.Binding); err != nil { + return err + } } if len(relationshipPattern.Kinds) > 0 { @@ -179,9 +178,34 @@ func (s CypherEmitter) formatRelationshipPattern(output io.Writer, relationshipP return nil } -func (s CypherEmitter) formatPatternPart(output io.Writer, patternPart *model.PatternPart) error { - if patternPart.Binding != "" { - if _, err := io.WriteString(output, patternPart.Binding); err != nil { +func (s Emitter) formatPatternElements(output io.Writer, patternElements []*model.PatternElement) error { + for idx, patternElement := range patternElements { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + // If this is another node pattern then output a delimiter + if idx >= 1 && patternElements[idx-1].IsNodePattern() { + if _, err := io.WriteString(output, ", "); err != nil { + return err + } + } + + if err := s.formatNodePattern(output, nodePattern); err != nil { + return 
err + } + } else if relationshipPattern, isRelationshipPattern := patternElement.AsRelationshipPattern(); isRelationshipPattern { + if err := s.formatRelationshipPattern(output, relationshipPattern); err != nil { + return err + } + } else { + return fmt.Errorf("invalid pattern element: %T(%+v)", patternElement, patternElement) + } + } + + return nil +} + +func (s Emitter) formatPatternPart(output io.Writer, patternPart *model.PatternPart) error { + if patternPart.Binding != nil { + if err := s.WriteExpression(output, patternPart.Binding); err != nil { return err } @@ -202,25 +226,8 @@ func (s CypherEmitter) formatPatternPart(output io.Writer, patternPart *model.Pa } } - for idx, patternElement := range patternPart.PatternElements { - if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { - // If this is another node pattern then output a delimiter - if idx >= 1 && patternPart.PatternElements[idx-1].IsNodePattern() { - if _, err := io.WriteString(output, ", "); err != nil { - return err - } - } - - if err := s.formatNodePattern(output, nodePattern); err != nil { - return err - } - } else if relationshipPattern, isRelationshipPattern := patternElement.AsRelationshipPattern(); isRelationshipPattern { - if err := s.formatRelationshipPattern(output, relationshipPattern); err != nil { - return err - } - } else { - return fmt.Errorf("invalid pattern element: %T(%+v)", patternElement, patternElement) - } + if err := s.formatPatternElements(output, patternPart.PatternElements); err != nil { + return err } if patternPart.ShortestPathPattern || patternPart.AllShortestPathsPattern { @@ -232,7 +239,7 @@ func (s CypherEmitter) formatPatternPart(output io.Writer, patternPart *model.Pa return nil } -func (s CypherEmitter) formatProjection(output io.Writer, projection *model.Projection) error { +func (s Emitter) formatProjection(output io.Writer, projection *model.Projection) error { if projection.Distinct { if _, err := io.WriteString(output, "distinct 
"); err != nil { return err @@ -246,18 +253,9 @@ func (s CypherEmitter) formatProjection(output io.Writer, projection *model.Proj } } - if err := s.WriteExpression(output, projectionItem.Expression); err != nil { + if err := s.WriteExpression(output, projectionItem); err != nil { return err } - - if projectionItem.Binding != nil { - if _, err := io.WriteString(output, " as "); err != nil { - return err - } - if _, err := io.WriteString(output, projectionItem.Binding.Symbol); err != nil { - return err - } - } } if projection.Order != nil { @@ -311,7 +309,7 @@ func (s CypherEmitter) formatProjection(output io.Writer, projection *model.Proj return nil } -func (s CypherEmitter) formatReturn(output io.Writer, returnClause *model.Return) error { +func (s Emitter) formatReturn(output io.Writer, returnClause *model.Return) error { if _, err := io.WriteString(output, " return "); err != nil { return err } @@ -323,7 +321,7 @@ func (s CypherEmitter) formatReturn(output io.Writer, returnClause *model.Return return nil } -func (s CypherEmitter) formatWhere(output io.Writer, whereClause *model.Where) error { +func (s Emitter) formatWhere(output io.Writer, whereClause *model.Where) error { if len(whereClause.Expressions) > 0 { if _, err := io.WriteString(output, " where "); err != nil { return err @@ -339,7 +337,7 @@ func (s CypherEmitter) formatWhere(output io.Writer, whereClause *model.Where) e return nil } -func (s CypherEmitter) formatMapLiteral(output io.Writer, mapLiteral model.MapLiteral) error { +func (s Emitter) formatMapLiteral(output io.Writer, mapLiteral model.MapLiteral) error { if _, err := io.WriteString(output, "{"); err != nil { return err } @@ -374,7 +372,7 @@ func (s CypherEmitter) formatMapLiteral(output io.Writer, mapLiteral model.MapLi return nil } -func (s CypherEmitter) formatLiteral(output io.Writer, literal *model.Literal) error { +func (s Emitter) formatLiteral(output io.Writer, literal *model.Literal) error { const literalNullToken = "null" // Check 
for a null literal first @@ -490,93 +488,111 @@ func (s CypherEmitter) formatLiteral(output io.Writer, literal *model.Literal) e return nil } -func (s CypherEmitter) WriteExpression(output io.Writer, expression model.Expression) error { +func (s Emitter) WriteExpression(writer io.Writer, expression model.Expression) error { switch typedExpression := expression.(type) { + case *model.ProjectionItem: + if err := s.WriteExpression(writer, typedExpression.Expression); err != nil { + return err + } + + if typedExpression.Binding != nil { + if _, err := io.WriteString(writer, " as "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedExpression.Binding); err != nil { + return err + } + } + case *model.Negation: - if _, err := io.WriteString(output, "not "); err != nil { + if _, err := io.WriteString(writer, "not "); err != nil { return err } switch innerExpression := typedExpression.Expression.(type) { case *model.Parenthetical: - if err := s.WriteExpression(output, innerExpression); err != nil { + if err := s.WriteExpression(writer, innerExpression); err != nil { return err } + default: - if _, err := io.WriteString(output, "("); err != nil { + if _, err := io.WriteString(writer, "("); err != nil { return err } - if err := s.WriteExpression(output, innerExpression); err != nil { + + if err := s.WriteExpression(writer, innerExpression); err != nil { return err } - if _, err := io.WriteString(output, ")"); err != nil { + + if _, err := io.WriteString(writer, ")"); err != nil { return err } } case *model.IDInCollection: - if err := s.WriteExpression(output, typedExpression.Variable); err != nil { + if err := s.WriteExpression(writer, typedExpression.Variable); err != nil { return err } - if _, err := io.WriteString(output, " in "); err != nil { + if _, err := io.WriteString(writer, " in "); err != nil { return err } - if err := s.WriteExpression(output, typedExpression.Expression); err != nil { + if err := s.WriteExpression(writer, 
typedExpression.Expression); err != nil { return err } case *model.FilterExpression: - if err := s.WriteExpression(output, typedExpression.Specifier); err != nil { + if err := s.WriteExpression(writer, typedExpression.Specifier); err != nil { return err } - if typedExpression.Where != nil { - if err := s.formatWhere(output, typedExpression.Where); err != nil { + if typedExpression.Where != nil && len(typedExpression.Where.Expressions) > 0 { + if err := s.formatWhere(writer, typedExpression.Where); err != nil { return err } } case *model.Quantifier: - if _, err := io.WriteString(output, typedExpression.Type.String()); err != nil { + if _, err := io.WriteString(writer, typedExpression.Type.String()); err != nil { return err } - if _, err := io.WriteString(output, "("); err != nil { + if _, err := io.WriteString(writer, "("); err != nil { return err } - if err := s.WriteExpression(output, typedExpression.Filter); err != nil { + if err := s.WriteExpression(writer, typedExpression.Filter); err != nil { return err } - if _, err := io.WriteString(output, ")"); err != nil { + if _, err := io.WriteString(writer, ")"); err != nil { return err } case *model.Parenthetical: - if _, err := io.WriteString(output, "("); err != nil { + if _, err := io.WriteString(writer, "("); err != nil { return err } - if err := s.WriteExpression(output, typedExpression.Expression); err != nil { + if err := s.WriteExpression(writer, typedExpression.Expression); err != nil { return err } - if _, err := io.WriteString(output, ")"); err != nil { + if _, err := io.WriteString(writer, ")"); err != nil { return err } case *model.Disjunction: for idx, joinedExpression := range typedExpression.Expressions { if idx > 0 { - if _, err := io.WriteString(output, " or "); err != nil { + if _, err := io.WriteString(writer, " or "); err != nil { return err } } - if err := s.WriteExpression(output, joinedExpression); err != nil { + if err := s.WriteExpression(writer, joinedExpression); err != nil { return err } } 
@@ -584,12 +600,12 @@ func (s CypherEmitter) WriteExpression(output io.Writer, expression model.Expres case *model.ExclusiveDisjunction: for idx, joinedExpression := range typedExpression.Expressions { if idx > 0 { - if _, err := io.WriteString(output, " xor "); err != nil { + if _, err := io.WriteString(writer, " xor "); err != nil { return err } } - if err := s.WriteExpression(output, joinedExpression); err != nil { + if err := s.WriteExpression(writer, joinedExpression); err != nil { return err } } @@ -597,200 +613,200 @@ func (s CypherEmitter) WriteExpression(output io.Writer, expression model.Expres case *model.Conjunction: for idx, joinedExpression := range typedExpression.Expressions { if idx > 0 { - if _, err := io.WriteString(output, " and "); err != nil { + if _, err := io.WriteString(writer, " and "); err != nil { return err } } - if err := s.WriteExpression(output, joinedExpression); err != nil { + if err := s.WriteExpression(writer, joinedExpression); err != nil { return err } } - case *model.PartialComparison: - if _, err := io.WriteString(output, " "); err != nil { + case *model.Comparison: + if err := s.WriteExpression(writer, typedExpression.Left); err != nil { return err } - if _, err := io.WriteString(output, typedExpression.Operator.String()); err != nil { - return err + for _, nextPart := range typedExpression.Partials { + if err := s.WriteExpression(writer, nextPart); err != nil { + return err + } } - if _, err := io.WriteString(output, " "); err != nil { + case *model.PartialComparison: + if _, err := io.WriteString(writer, " "); err != nil { return err } - if err := s.WriteExpression(output, typedExpression.Right); err != nil { + if _, err := io.WriteString(writer, typedExpression.Operator.String()); err != nil { return err } - case *model.Comparison: - if err := s.WriteExpression(output, typedExpression.Left); err != nil { + if _, err := io.WriteString(writer, " "); err != nil { return err } - for _, nextPart := range 
typedExpression.Partials { - if err := s.WriteExpression(output, nextPart); err != nil { - return err - } + if err := s.WriteExpression(writer, typedExpression.Right); err != nil { + return err } case *model.Properties: if typedExpression.Map != nil { - if err := s.formatMapLiteral(output, typedExpression.Map); err != nil { + if err := s.formatMapLiteral(writer, typedExpression.Map); err != nil { return err } - } else if err := s.WriteExpression(output, typedExpression.Parameter); err != nil { + } else if err := s.WriteExpression(writer, typedExpression.Parameter); err != nil { return err } case *model.Variable: - if _, err := io.WriteString(output, typedExpression.Symbol); err != nil { + if _, err := io.WriteString(writer, typedExpression.Symbol); err != nil { return err } case *model.Parameter: - if _, err := io.WriteString(output, "$"); err != nil { + if _, err := io.WriteString(writer, "$"); err != nil { return err } - if _, err := io.WriteString(output, typedExpression.Symbol); err != nil { + if _, err := io.WriteString(writer, typedExpression.Symbol); err != nil { return err } case *model.PropertyLookup: - if err := s.WriteExpression(output, typedExpression.Atom); err != nil { + if err := s.WriteExpression(writer, typedExpression.Atom); err != nil { return err } - if _, err := io.WriteString(output, "."); err != nil { + if _, err := io.WriteString(writer, "."); err != nil { return err } - if _, err := io.WriteString(output, strings.Join(typedExpression.Symbols, ".")); err != nil { + if _, err := io.WriteString(writer, strings.Join(typedExpression.Symbols, ".")); err != nil { return err } case *model.FunctionInvocation: - if _, err := io.WriteString(output, strings.Join(typedExpression.Namespace, ".")); err != nil { + if _, err := io.WriteString(writer, strings.Join(typedExpression.Namespace, ".")); err != nil { return err } - if _, err := io.WriteString(output, typedExpression.Name); err != nil { + if _, err := io.WriteString(writer, typedExpression.Name); 
err != nil { return err } - if _, err := io.WriteString(output, "("); err != nil { + if _, err := io.WriteString(writer, "("); err != nil { return err } if typedExpression.Distinct { - if _, err := io.WriteString(output, "distinct"); err != nil { + if _, err := io.WriteString(writer, "distinct "); err != nil { return err } } for idx, subExpression := range typedExpression.Arguments { if idx > 0 { - if _, err := io.WriteString(output, ", "); err != nil { + if _, err := io.WriteString(writer, ", "); err != nil { return err } } - if err := s.WriteExpression(output, subExpression); err != nil { + if err := s.WriteExpression(writer, subExpression); err != nil { return err } } - if _, err := io.WriteString(output, ")"); err != nil { + if _, err := io.WriteString(writer, ")"); err != nil { return err } case graph.Kind: - if _, err := io.WriteString(output, ":"); err != nil { + if _, err := io.WriteString(writer, ":"); err != nil { return err } - if _, err := io.WriteString(output, typedExpression.String()); err != nil { + if _, err := io.WriteString(writer, typedExpression.String()); err != nil { return err } case graph.Kinds: - if _, err := io.WriteString(output, ":"); err != nil { + if _, err := io.WriteString(writer, ":"); err != nil { return err } - if err := writeJoinedKinds(output, ":", typedExpression); err != nil { + if err := writeJoinedKinds(writer, ":", typedExpression); err != nil { return err } case *model.KindMatcher: - if err := s.WriteExpression(output, typedExpression.Reference); err != nil { + if err := s.WriteExpression(writer, typedExpression.Reference); err != nil { return err } for _, matcher := range typedExpression.Kinds { - if _, err := io.WriteString(output, ":"); err != nil { + if _, err := io.WriteString(writer, ":"); err != nil { return err } - if _, err := io.WriteString(output, matcher.String()); err != nil { + if _, err := io.WriteString(writer, matcher.String()); err != nil { return err } } case *model.RangeQuantifier: - if _, err := 
io.WriteString(output, typedExpression.Value); err != nil { + if _, err := io.WriteString(writer, typedExpression.Value); err != nil { return err } case model.Operator: - if _, err := io.WriteString(output, typedExpression.String()); err != nil { + if _, err := io.WriteString(writer, typedExpression.String()); err != nil { return err } case *model.Skip: - return s.WriteExpression(output, typedExpression.Value) + return s.WriteExpression(writer, typedExpression.Value) case *model.Limit: - return s.WriteExpression(output, typedExpression.Value) + return s.WriteExpression(writer, typedExpression.Value) case *model.Literal: if !s.StripLiterals { - return s.formatLiteral(output, typedExpression) + return s.formatLiteral(writer, typedExpression) } else { - _, err := io.WriteString(output, strippedLiteral) + _, err := io.WriteString(writer, strippedLiteral) return err } - case []*model.PatternPart: - return s.formatPattern(output, typedExpression) + case *model.PatternPredicate: + return s.formatPatternElements(writer, typedExpression.PatternElements) case *model.ArithmeticExpression: - if err := s.WriteExpression(output, typedExpression.Left); err != nil { + if err := s.WriteExpression(writer, typedExpression.Left); err != nil { return err } for _, part := range typedExpression.Partials { - if err := s.WriteExpression(output, part); err != nil { + if err := s.WriteExpression(writer, part); err != nil { return err } } case *model.PartialArithmeticExpression: - if _, err := io.WriteString(output, " "); err != nil { + if _, err := io.WriteString(writer, " "); err != nil { return err } - if _, err := io.WriteString(output, typedExpression.Operator.String()); err != nil { + if _, err := io.WriteString(writer, typedExpression.Operator.String()); err != nil { return err } - if _, err := io.WriteString(output, " "); err != nil { + if _, err := io.WriteString(writer, " "); err != nil { return err } - return s.WriteExpression(output, typedExpression.Right) + return 
s.WriteExpression(writer, typedExpression.Right) default: return fmt.Errorf("unexpected expression type for string formatting: %T", expression) @@ -799,7 +815,7 @@ func (s CypherEmitter) WriteExpression(output io.Writer, expression model.Expres return nil } -func (s CypherEmitter) formatRemove(output io.Writer, remove *model.Remove) error { +func (s Emitter) formatRemove(output io.Writer, remove *model.Remove) error { if _, err := io.WriteString(output, "remove "); err != nil { return err } @@ -829,7 +845,7 @@ func (s CypherEmitter) formatRemove(output io.Writer, remove *model.Remove) erro return nil } -func (s CypherEmitter) formatSet(output io.Writer, set *model.Set) error { +func (s Emitter) formatSet(output io.Writer, set *model.Set) error { if _, err := io.WriteString(output, "set "); err != nil { return err } @@ -869,7 +885,7 @@ func (s CypherEmitter) formatSet(output io.Writer, set *model.Set) error { return nil } -func (s CypherEmitter) formatDelete(output io.Writer, delete *model.Delete) error { +func (s Emitter) formatDelete(output io.Writer, delete *model.Delete) error { if delete.Detach { if _, err := io.WriteString(output, "detach delete "); err != nil { return err @@ -893,7 +909,7 @@ func (s CypherEmitter) formatDelete(output io.Writer, delete *model.Delete) erro return nil } -func (s CypherEmitter) formatPattern(output io.Writer, pattern []*model.PatternPart) error { +func (s Emitter) formatPattern(output io.Writer, pattern []*model.PatternPart) error { for idx, patternPart := range pattern { if idx > 0 { if _, err := io.WriteString(output, ", "); err != nil { @@ -909,7 +925,7 @@ func (s CypherEmitter) formatPattern(output io.Writer, pattern []*model.PatternP return nil } -func (s CypherEmitter) formatCreate(output io.Writer, create *model.Create) error { +func (s Emitter) formatCreate(output io.Writer, create *model.Create) error { if _, err := io.WriteString(output, "create "); err != nil { return err } @@ -917,7 +933,7 @@ func (s CypherEmitter) 
formatCreate(output io.Writer, create *model.Create) erro return s.formatPattern(output, create.Pattern) } -func (s CypherEmitter) formatUpdatingClause(output io.Writer, updatingClause *model.UpdatingClause) error { +func (s Emitter) formatUpdatingClause(output io.Writer, updatingClause *model.UpdatingClause) error { switch typedClause := updatingClause.Clause.(type) { case *model.Create: return s.formatCreate(output, typedClause) @@ -936,7 +952,7 @@ func (s CypherEmitter) formatUpdatingClause(output io.Writer, updatingClause *mo } } -func (s CypherEmitter) formatReadingClause(output io.Writer, readingClause *model.ReadingClause) error { +func (s Emitter) formatReadingClause(output io.Writer, readingClause *model.ReadingClause) error { if readingClause.Match != nil { if readingClause.Match.Optional { if _, err := io.WriteString(output, "optional "); err != nil { @@ -960,7 +976,7 @@ func (s CypherEmitter) formatReadingClause(output io.Writer, readingClause *mode } } - if readingClause.Match.Where != nil { + if readingClause.Match.Where != nil && len(readingClause.Match.Where.Expressions) > 0 { if err := s.formatWhere(output, readingClause.Match.Where); err != nil { return err } @@ -988,47 +1004,49 @@ func (s CypherEmitter) formatReadingClause(output io.Writer, readingClause *mode return nil } -func (s CypherEmitter) formatSinglePartQuery(output io.Writer, singlePartQuery *model.SinglePartQuery) error { +func (s Emitter) formatSinglePartQuery(writer io.Writer, singlePartQuery *model.SinglePartQuery) error { for idx, readingClause := range singlePartQuery.ReadingClauses { if idx > 0 { - if _, err := io.WriteString(output, " "); err != nil { + if _, err := io.WriteString(writer, " "); err != nil { return err } } - if err := s.formatReadingClause(output, readingClause); err != nil { + if err := s.formatReadingClause(writer, readingClause); err != nil { return err } } if len(singlePartQuery.UpdatingClauses) > 0 { if len(singlePartQuery.ReadingClauses) > 0 { - if _, err 
:= io.WriteString(output, " "); err != nil { + if _, err := io.WriteString(writer, " "); err != nil { return err } } for idx, updatingClause := range singlePartQuery.UpdatingClauses { if idx > 0 { - if _, err := io.WriteString(output, " "); err != nil { + if _, err := io.WriteString(writer, " "); err != nil { return err } } - if err := s.formatUpdatingClause(output, updatingClause); err != nil { + if typedUpdatingClause, typeOK := updatingClause.(*model.UpdatingClause); !typeOK { + return fmt.Errorf("unexpected updating clause type %T", updatingClause) + } else if err := s.formatUpdatingClause(writer, typedUpdatingClause); err != nil { return err } } } if singlePartQuery.Return != nil { - return s.formatReturn(output, singlePartQuery.Return) + return s.formatReturn(writer, singlePartQuery.Return) } return nil } -func (s CypherEmitter) formatWith(output io.Writer, with *model.With) error { +func (s Emitter) formatWith(output io.Writer, with *model.With) error { if _, err := io.WriteString(output, "with "); err != nil { return err } @@ -1037,7 +1055,7 @@ func (s CypherEmitter) formatWith(output io.Writer, with *model.With) error { return err } - if with.Where != nil { + if with.Where != nil && len(with.Where.Expressions) > 0 { if err := s.formatWhere(output, with.Where); err != nil { return err } @@ -1046,7 +1064,7 @@ func (s CypherEmitter) formatWith(output io.Writer, with *model.With) error { return nil } -func (s CypherEmitter) formatMultiPartQuery(output io.Writer, multiPartQuery *model.MultiPartQuery) error { +func (s Emitter) formatMultiPartQuery(output io.Writer, multiPartQuery *model.MultiPartQuery) error { for idx, multiPartQueryPart := range multiPartQuery.Parts { var ( numReadingClauses = len(multiPartQueryPart.ReadingClauses) @@ -1117,7 +1135,7 @@ func (s CypherEmitter) formatMultiPartQuery(output io.Writer, multiPartQuery *mo return nil } -func (s CypherEmitter) Write(regularQuery *model.RegularQuery, writer io.Writer) error { +func (s Emitter) 
Write(regularQuery *model.RegularQuery, writer io.Writer) error { if regularQuery.SingleQuery != nil { if regularQuery.SingleQuery.MultiPartQuery != nil { if err := s.formatMultiPartQuery(writer, regularQuery.SingleQuery.MultiPartQuery); err != nil { diff --git a/packages/go/cypher/frontend/format_test.go b/packages/go/cypher/backend/cypher/format_test.go similarity index 63% rename from packages/go/cypher/frontend/format_test.go rename to packages/go/cypher/backend/cypher/format_test.go index 90be19c1a4..9024153908 100644 --- a/packages/go/cypher/frontend/format_test.go +++ b/packages/go/cypher/backend/cypher/format_test.go @@ -14,19 +14,23 @@ // // SPDX-License-Identifier: Apache-2.0 -package frontend +package cypher_test import ( "bytes" + "github.com/specterops/bloodhound/cypher/backend/cypher" + "github.com/specterops/bloodhound/cypher/frontend" "github.com/stretchr/testify/require" "testing" + + "github.com/specterops/bloodhound/cypher/test" ) func TestCypherEmitter_StripLiterals(t *testing.T) { var ( buffer = &bytes.Buffer{} - regularQuery, err = ParseCypher(DefaultCypherContext(), "match (n {value: 'PII'}) where n.other = 'more pii' and n.number = 411 return n.name, n") - emitter = CypherEmitter{ + regularQuery, err = frontend.ParseCypher(frontend.DefaultCypherContext(), "match (n {value: 'PII'}) where n.other = 'more pii' and n.number = 411 return n.name, n") + emitter = cypher.Emitter{ StripLiterals: true, } ) @@ -35,3 +39,11 @@ func TestCypherEmitter_StripLiterals(t *testing.T) { require.Nil(t, emitter.Write(regularQuery, buffer)) require.Equal(t, "match (n {value: $STRIPPED}) where n.other = $STRIPPED and n.number = $STRIPPED return n.name, n", buffer.String()) } + +func TestCypherEmitter_HappyPath(t *testing.T) { + test.LoadFixture(t, test.PositiveTestCases).Run(t) +} + +func TestCypherEmitter_NegativeCases(t *testing.T) { + test.LoadFixture(t, test.NegativeTestCases).Run(t) +} diff --git a/packages/go/cypher/backend/gen.go 
b/packages/go/cypher/backend/gen.go new file mode 100644 index 0000000000..02e1f02549 --- /dev/null +++ b/packages/go/cypher/backend/gen.go @@ -0,0 +1,49 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package backend + +import ( + "bytes" + "github.com/specterops/bloodhound/cypher/backend/cypher" + "github.com/specterops/bloodhound/cypher/frontend" + "github.com/specterops/bloodhound/cypher/model" + "io" +) + +type Emitter interface { + Write(query *model.RegularQuery, writer io.Writer) error + WriteExpression(output io.Writer, expression model.Expression) error +} + +func CypherToCypher(ctx *frontend.Context, input string) (string, error) { + if query, err := frontend.ParseCypher(ctx, input); err != nil { + return "", err + } else { + var ( + output = &bytes.Buffer{} + emitter = cypher.Emitter{ + StripLiterals: false, + } + ) + + if err := emitter.Write(query, output); err != nil { + return "", err + } + + return output.String(), nil + } +} diff --git a/packages/go/cypher/backend/pgsql/facts.go b/packages/go/cypher/backend/pgsql/facts.go new file mode 100644 index 0000000000..6546b494a8 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/facts.go @@ -0,0 +1,35 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +const ( + cypherCountFunction = "count" + cypherDateFunction = "date" + cypherTimeFunction = "time" + cypherLocalTimeFunction = "localtime" + cypherDateTimeFunction = "datetime" + cypherLocalDateTimeFunction = "localdatetime" + cypherDurationFunction = "duration" + cypherIdentityFunction = "id" + cypherToLowerFunction = "toLower" + cypherNodeLabelsFunction = "labels" + cypherEdgeTypeFunction = "type" + + pgsqlAnyFunction = "any" + pgsqlToJSONBFunction = "to_jsonb" + pgsqlToLowerFunction = "lower" +) diff --git a/packages/go/cypher/backend/pgsql/format.go b/packages/go/cypher/backend/pgsql/format.go new file mode 100644 index 0000000000..3ca574f9d7 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/format.go @@ -0,0 +1,1318 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +import ( + "fmt" + "github.com/specterops/bloodhound/cypher/model" + pgModel "github.com/specterops/bloodhound/cypher/model/pg" + pgDriverModel "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + "github.com/specterops/bloodhound/dawgs/graph" + "io" + "strconv" +) + +const strippedLiteral = "$STRIPPED" + +type KindMapper interface { + MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) +} + +type Emitter struct { + StripLiterals bool + kindMapper KindMapper +} + +func NewEmitter(stripLiterals bool, kindMapper KindMapper) *Emitter { + return &Emitter{ + StripLiterals: stripLiterals, + kindMapper: kindMapper, + } +} + +func (s *Emitter) formatMapLiteral(output io.Writer, mapLiteral model.MapLiteral) error { + if _, err := io.WriteString(output, "{"); err != nil { + return err + } + + first := true + for key, subExpression := range mapLiteral { + if !first { + if _, err := io.WriteString(output, ", "); err != nil { + return err + } + } else { + first = false + } + + if _, err := io.WriteString(output, key); err != nil { + return err + } + + if _, err := io.WriteString(output, ": "); err != nil { + return err + } + + if err := s.WriteExpression(output, subExpression); err != nil { + return err + } + } + + if _, err := io.WriteString(output, "}"); err != nil { + return err + } + + return nil +} + +func (s *Emitter) formatLiteral(output io.Writer, literal *model.Literal) error { + const literalNullToken = "null" + + // Check for a null literal first + if literal.Null { + if _, err := io.WriteString(output, literalNullToken); err != nil { + return err + } + return nil + } + + // Attempt to string format the literal value + switch typedLiteral := literal.Value.(type) { + case string: + // Note: the cypher AST model expects literal strings to be wrapped in single quote characters (') so no + // additional formatting is done here + if _, err := WriteStrings(output, typedLiteral); err != nil { + return err 
+ } + + case graph.ID: + if _, err := io.WriteString(output, strconv.FormatInt(int64(typedLiteral), 10)); err != nil { + return err + } + + case int8: + if _, err := io.WriteString(output, strconv.FormatInt(int64(typedLiteral), 10)); err != nil { + return err + } + + case int16: + if _, err := io.WriteString(output, strconv.FormatInt(int64(typedLiteral), 10)); err != nil { + return err + } + + case int32: + if _, err := io.WriteString(output, strconv.FormatInt(int64(typedLiteral), 10)); err != nil { + return err + } + + case int64: + if _, err := io.WriteString(output, strconv.FormatInt(typedLiteral, 10)); err != nil { + return err + } + + case int: + if _, err := io.WriteString(output, strconv.FormatInt(int64(typedLiteral), 10)); err != nil { + return err + } + + case uint8: + if _, err := io.WriteString(output, strconv.FormatUint(uint64(typedLiteral), 10)); err != nil { + return err + } + + case uint16: + if _, err := io.WriteString(output, strconv.FormatUint(uint64(typedLiteral), 10)); err != nil { + return err + } + + case uint32: + if _, err := io.WriteString(output, strconv.FormatUint(uint64(typedLiteral), 10)); err != nil { + return err + } + + case uint64: + if _, err := io.WriteString(output, strconv.FormatUint(typedLiteral, 10)); err != nil { + return err + } + + case uint: + if _, err := io.WriteString(output, strconv.FormatUint(uint64(typedLiteral), 10)); err != nil { + return err + } + + case bool: + if _, err := io.WriteString(output, strconv.FormatBool(typedLiteral)); err != nil { + return err + } + + case float32: + if _, err := io.WriteString(output, strconv.FormatFloat(float64(typedLiteral), 'f', -1, 64)); err != nil { + return err + } + + case float64: + if _, err := io.WriteString(output, strconv.FormatFloat(typedLiteral, 'f', -1, 64)); err != nil { + return err + } + + case model.MapLiteral: + if err := s.formatMapLiteral(output, typedLiteral); err != nil { + return err + } + + case *model.ListLiteral: + if _, err := io.WriteString(output, 
"array["); err != nil { + return err + } + + for idx, subExpression := range *typedLiteral { + if idx > 0 { + if _, err := io.WriteString(output, ", "); err != nil { + return err + } + } + + if err := s.WriteExpression(output, subExpression); err != nil { + return err + } + } + + if _, err := io.WriteString(output, "]"); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected literal type for string formatting: %T", literal.Value) + } + + return nil +} + +func (s *Emitter) writeReturn(writer io.Writer, returnClause *model.Return) error { + if returnClause.Projection.Distinct { + if _, err := WriteStrings(writer, "distinct "); err != nil { + return err + } + } + + for idx, projectionItem := range returnClause.Projection.Items { + if idx > 0 { + if _, err := io.WriteString(writer, ", "); err != nil { + return nil + } + } + + if err := s.WriteExpression(writer, projectionItem); err != nil { + return err + } + } + + return nil +} + +func (s *Emitter) writeWhere(writer io.Writer, whereClause *model.Where) error { + if len(whereClause.Expressions) > 0 { + if _, err := io.WriteString(writer, " where "); err != nil { + return err + } + } + + for _, expression := range whereClause.Expressions { + if err := s.WriteExpression(writer, expression); err != nil { + return err + } + } + + return nil +} + +func (s *Emitter) writePatternElements(writer io.Writer, patternElements []*model.PatternElement) error { + for idx, patternElement := range patternElements { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + if idx == 0 { + if _, err := io.WriteString(writer, pgDriverModel.NodeTable); err != nil { + return nil + } + + if _, err := io.WriteString(writer, " as "); err != nil { + return nil + } + + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return nil + } + } else { + previousRelationshipPattern, _ := patternElements[idx-1].AsRelationshipPattern() + + if _, err := WriteStrings(writer, " join ", 
pgDriverModel.NodeTable, " "); err != nil { + return err + } + + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, " on "); err != nil { + return err + } + + switch previousRelationshipPattern.Direction { + case graph.DirectionOutbound: + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousRelationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".end_id"); err != nil { + return err + } + + case graph.DirectionInbound: + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousRelationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".start_id"); err != nil { + return err + } + + default: + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousRelationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".start_id or "); err != nil { + return err + } + + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousRelationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".end_id "); err != nil { + return err + } + } + } + } else { + relationshipPattern, _ := patternElement.AsRelationshipPattern() + + if idx == 0 { + if _, err := io.WriteString(writer, pgDriverModel.EdgeTable); err != nil { + return nil + } + + 
if _, err := io.WriteString(writer, " as "); err != nil { + return nil + } + + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return nil + } + } else { + previousNodePattern, _ := patternElements[idx-1].AsNodePattern() + + if _, err := WriteStrings(writer, " join ", pgDriverModel.EdgeTable, " "); err != nil { + return err + } + + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, " on "); err != nil { + return err + } + + switch relationshipPattern.Direction { + case graph.DirectionOutbound: + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".start_id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousNodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id"); err != nil { + return err + } + + case graph.DirectionInbound: + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".end_id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousNodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id"); err != nil { + return err + } + + case graph.DirectionBoth: + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".start_id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, previousNodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id or "); err != nil { + return err + } + + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".end_id = "); err != nil { + return err + } + + if err := s.WriteExpression(writer, 
previousNodePattern.Binding); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".id"); err != nil { + return err + } + + default: + return fmt.Errorf("unsupported pattern direction: %s", relationshipPattern.Direction) + } + } + } + } + + return nil +} + +func (s *Emitter) writeMatch(writer io.Writer, matchClause *model.Match) error { + for idx, pattern := range matchClause.Pattern { + if idx > 0 { + if _, err := io.WriteString(writer, ", "); err != nil { + return err + } + } + + if err := s.writePatternElements(writer, pattern.PatternElements); err != nil { + return err + } + } + + if matchClause.Where != nil { + if err := s.writeWhere(writer, matchClause.Where); err != nil { + return err + } + } + + return nil +} + +func (s *Emitter) writeSelect(writer io.Writer, singlePartQuery *model.SinglePartQuery) error { + if _, err := io.WriteString(writer, "select "); err != nil { + return err + } + + if singlePartQuery.Return != nil { + if err := s.writeReturn(writer, singlePartQuery.Return); err != nil { + return err + } + } + + if _, err := io.WriteString(writer, " from "); err != nil { + return err + } + + for _, readingClause := range singlePartQuery.ReadingClauses { + if readingClause.Match != nil { + if err := s.writeMatch(writer, readingClause.Match); err != nil { + return err + } + } + } + + if singlePartQuery.Return != nil { + if order := singlePartQuery.Return.Projection.Order; order != nil { + if _, err := WriteStrings(writer, " order by "); err != nil { + return err + } + + for idx, orderItem := range order.Items { + if idx > 0 { + if _, err := WriteStrings(writer, ", "); err != nil { + return err + } + } + + if err := s.WriteExpression(writer, orderItem.Expression); err != nil { + return err + } + + if orderItem.Ascending { + if _, err := WriteStrings(writer, " asc"); err != nil { + return err + } + } else { + if _, err := WriteStrings(writer, " desc"); err != nil { + return err + } + } + } + } + + if skip := 
singlePartQuery.Return.Projection.Skip; skip != nil { + if _, err := WriteStrings(writer, " offset "); err != nil { + return err + } + + if err := s.WriteExpression(writer, skip.Value); err != nil { + return err + } + } + + if limit := singlePartQuery.Return.Projection.Limit; limit != nil { + if _, err := WriteStrings(writer, " limit "); err != nil { + return err + } + + if err := s.WriteExpression(writer, limit.Value); err != nil { + return err + } + } + } + + return nil +} + +func (s *Emitter) writeDelete(writer io.Writer, singlePartQuery *model.SinglePartQuery, delete *pgModel.Delete) error { + if delete.NodeDelete { + if _, err := WriteStrings(writer, "delete from ", pgDriverModel.NodeTable, " as ", delete.Binding.Symbol); err != nil { + return err + } + + first := true + + for _, readingClause := range singlePartQuery.ReadingClauses { + if matchClause := readingClause.Match; matchClause != nil { + for _, pattern := range matchClause.Pattern { + for _, patternElement := range pattern.PatternElements { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + switch typedBinding := nodePattern.Binding.(type) { + case *pgModel.AnnotatedVariable: + if typedBinding.Symbol == delete.Binding.Symbol { + continue + } + } + + if !first { + if _, err := WriteStrings(writer, ", "); err != nil { + return err + } + } else { + if _, err := WriteStrings(writer, " using "); err != nil { + return err + } + + first = false + } + + if _, err := WriteStrings(writer, pgDriverModel.NodeTable, " as "); err != nil { + return err + } + + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + } else { + relationshipPattern, _ := patternElement.AsRelationshipPattern() + + switch typedBinding := relationshipPattern.Binding.(type) { + case *pgModel.AnnotatedVariable: + if typedBinding.Symbol == delete.Binding.Symbol { + continue + } + } + + if !first { + if _, err := WriteStrings(writer, ", "); err != nil { + return err + } + } 
else { + if _, err := WriteStrings(writer, " using "); err != nil { + return err + } + + first = false + } + + if _, err := WriteStrings(writer, pgDriverModel.EdgeTable, " as "); err != nil { + return err + } + + if err := s.WriteExpression(writer, relationshipPattern.Binding); err != nil { + return err + } + } + } + } + } + } + } else { + if _, err := WriteStrings(writer, "delete from ", pgDriverModel.EdgeTable, " as ", delete.Binding.Symbol); err != nil { + return err + } + + first := true + + for _, readingClause := range singlePartQuery.ReadingClauses { + if matchClause := readingClause.Match; matchClause != nil { + for _, pattern := range matchClause.Pattern { + for _, patternElement := range pattern.PatternElements { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + if !first { + if _, err := WriteStrings(writer, ", "); err != nil { + return err + } + } else { + if _, err := WriteStrings(writer, " using "); err != nil { + return err + } + + first = false + } + + if _, err := WriteStrings(writer, pgDriverModel.NodeTable, " as "); err != nil { + return err + } + + if err := s.WriteExpression(writer, nodePattern.Binding); err != nil { + return err + } + } + } + } + } + } + } + + for _, readingClause := range singlePartQuery.ReadingClauses { + if matchClause := readingClause.Match; matchClause != nil { + if matchClause.Where != nil { + if err := s.writeWhere(writer, matchClause.Where); err != nil { + return err + } + } + } + } + + return nil +} + +func (s *Emitter) writeUpdates(writer io.Writer, singlePartQuery *model.SinglePartQuery) error { + if _, err := io.WriteString(writer, "update "); err != nil { + return err + } + + for _, readingClause := range singlePartQuery.ReadingClauses { + if matchClause := readingClause.Match; matchClause != nil { + for idx, pattern := range matchClause.Pattern { + if idx > 0 { + if _, err := io.WriteString(writer, ", "); err != nil { + return err + } + } + + if err := 
s.writePatternElements(writer, pattern.PatternElements); err != nil { + return err + } + } + } + } + + if _, err := WriteStrings(writer, " set "); err != nil { + return err + } + + for idx, item := range singlePartQuery.UpdatingClauses { + if idx > 0 { + if _, err := WriteStrings(writer, ", "); err != nil { + return err + } + } + + switch typedUpdateItem := item.(type) { + case *pgModel.Delete: + if err := s.writeDelete(writer, singlePartQuery, typedUpdateItem); err != nil { + return err + } + + case *pgModel.PropertyMutation: + // Can't use aliased names in the set clauses of the SQL statement so default to just the raw + // column names + if _, err := WriteStrings(writer, "properties = properties"); err != nil { + return err + } + + if typedUpdateItem.Additions != nil { + if typedUpdateItem.Removals != nil { + if _, err := WriteStrings(writer, " - "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedUpdateItem.Removals); err != nil { + return err + } + + if _, err := WriteStrings(writer, "::text[]"); err != nil { + return err + } + } + + if _, err := WriteStrings(writer, " || "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedUpdateItem.Additions); err != nil { + return err + } + } else if typedUpdateItem.Removals != nil { + if _, err := WriteStrings(writer, " - "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedUpdateItem.Removals); err != nil { + return err + } + + if _, err := WriteStrings(writer, "::text[]"); err != nil { + return err + } + } + + case *pgModel.KindMutation: + // Cypher and therefore this translation does not support kind mutation of relationships + if typedUpdateItem.Variable.Type != pgModel.Node { + return fmt.Errorf("unsupported SQL type for kind mutation: %s", typedUpdateItem.Variable.Type) + } + + // Can't use aliased names in the set clauses of the SQL statement so default to just the raw + // column names + if _, err := WriteStrings(writer, "kind_ids = 
kind_ids"); err != nil { + return err + } + + if typedUpdateItem.Additions != nil { + if typedUpdateItem.Removals != nil { + if _, err := WriteStrings(writer, " - "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedUpdateItem.Removals); err != nil { + return err + } + } + + if _, err := WriteStrings(writer, " || "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedUpdateItem.Additions); err != nil { + return err + } + } else if typedUpdateItem.Removals != nil { + if _, err := WriteStrings(writer, " - "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedUpdateItem.Removals); err != nil { + return err + } + } + + default: + return fmt.Errorf("unsupported update clause item: %T", item) + } + } + + for _, readingClause := range singlePartQuery.ReadingClauses { + if matchClause := readingClause.Match; matchClause != nil { + if matchClause.Where != nil { + if err := s.writeWhere(writer, matchClause.Where); err != nil { + return err + } + } + } + } + + if singlePartQuery.Return != nil { + if _, err := WriteStrings(writer, " returning "); err != nil { + return err + } + + if err := s.writeReturn(writer, singlePartQuery.Return); err != nil { + return err + } + } + + return nil +} + +func (s *Emitter) writeUpdatingClauses(writer io.Writer, singlePartQuery *model.SinglePartQuery) error { + // Delete statements must be rendered as their own outputs + numDeletes := 0 + + for _, updateClause := range singlePartQuery.UpdatingClauses { + switch typedClause := updateClause.(type) { + case *pgModel.Delete: + numDeletes++ + + if err := s.writeDelete(writer, singlePartQuery, typedClause); err != nil { + return err + } + } + } + + if len(singlePartQuery.UpdatingClauses) > numDeletes { + if err := s.writeUpdates(writer, singlePartQuery); err != nil { + return err + } + } + + return nil +} + +func (s *Emitter) writeSinglePartQuery(writer io.Writer, singlePartQuery *model.SinglePartQuery) error { + if 
len(singlePartQuery.UpdatingClauses) > 0 { + return s.writeUpdatingClauses(writer, singlePartQuery) + } else { + return s.writeSelect(writer, singlePartQuery) + } +} + +func (s *Emitter) writeSubquery(writer io.Writer, subquery *pgModel.Subquery) error { + if _, err := io.WriteString(writer, "exists(select * from "); err != nil { + return err + } + + if err := s.writePatternElements(writer, subquery.PatternElements); err != nil { + return err + } + + if subquery.Filter != nil { + subQueryWhereClause := model.NewWhere() + subQueryWhereClause.Add(subquery.Filter) + + if err := s.writeWhere(writer, subQueryWhereClause); err != nil { + return err + } + } + + if _, err := io.WriteString(writer, " limit 1)"); err != nil { + return err + } + + return nil +} + +func (s *Emitter) WriteExpression(writer io.Writer, expression model.Expression) error { + switch typedExpression := expression.(type) { + case *pgModel.Subquery: + if err := s.writeSubquery(writer, typedExpression); err != nil { + return err + } + + case *model.Negation: + if _, err := io.WriteString(writer, "not "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedExpression.Expression); err != nil { + return err + } + + case *model.Disjunction: + for idx, joinedExpression := range typedExpression.Expressions { + if idx > 0 { + if _, err := io.WriteString(writer, " or "); err != nil { + return err + } + } + + if err := s.WriteExpression(writer, joinedExpression); err != nil { + return err + } + } + + case *model.Conjunction: + for idx, joinedExpression := range typedExpression.Expressions { + if idx > 0 { + if _, err := io.WriteString(writer, " and "); err != nil { + return err + } + } + + if err := s.WriteExpression(writer, joinedExpression); err != nil { + return err + } + } + + case *model.Comparison: + if err := s.WriteExpression(writer, typedExpression.Left); err != nil { + return err + } + + for _, nextPart := range typedExpression.Partials { + if err := s.WriteExpression(writer, 
nextPart); err != nil { + return err + } + } + + case *model.PartialComparison: + if _, err := WriteStrings(writer, " ", typedExpression.Operator.String(), " "); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedExpression.Right); err != nil { + return err + } + + case *pgModel.AnnotatedLiteral: + if err := s.WriteExpression(writer, &typedExpression.Literal); err != nil { + return err + } + + case *model.Literal: + if !s.StripLiterals { + return s.formatLiteral(writer, typedExpression) + } else { + _, err := io.WriteString(writer, strippedLiteral) + return err + } + + case *model.Variable: + if _, err := io.WriteString(writer, typedExpression.Symbol); err != nil { + return err + } + + case *pgModel.AnnotatedVariable: + if _, err := io.WriteString(writer, typedExpression.Symbol); err != nil { + return err + } + + case *pgModel.Entity: + switch typedExpression.Binding.Type { + case pgModel.Node: + if _, err := WriteStrings(writer, "(", typedExpression.Binding.Symbol, ".id, ", typedExpression.Binding.Symbol, ".kind_ids, ", typedExpression.Binding.Symbol, ".properties)::nodeComposite"); err != nil { + return err + } + + case pgModel.Edge: + if _, err := WriteStrings(writer, "(", typedExpression.Binding.Symbol, ".id, ", typedExpression.Binding.Symbol, ".start_id, ", typedExpression.Binding.Symbol, ".end_id, ", typedExpression.Binding.Symbol, ".kind_id, ", typedExpression.Binding.Symbol, ".properties)::edgeComposite"); err != nil { + return err + } + + case pgModel.Path: + if _, err := WriteStrings(writer, "edges_to_path(", ")"); err != nil { + return err + } + + default: + return fmt.Errorf("unsupported entity type %s", typedExpression.Binding.Type) + } + + case *pgModel.NodeKindsReference: + if err := s.WriteExpression(writer, typedExpression.Variable); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".kind_ids"); err != nil { + return err + } + + case *pgModel.EdgeKindReference: + if err := s.WriteExpression(writer, 
typedExpression.Variable); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".kind_id"); err != nil { + return err + } + + case *pgModel.AnnotatedPropertyLookup: + if _, err := io.WriteString(writer, "("); err != nil { + return nil + } + + if err := s.WriteExpression(writer, typedExpression.Atom); err != nil { + return err + } + + switch typedExpression.Type { + case + // We can't directly cast from JSONB types to time types since they require parsing first. The '->>' + // operator coerces the underlying JSONB value to text before type casting + pgModel.Date, pgModel.TimeWithTimeZone, pgModel.TimeWithoutTimeZone, pgModel.TimestampWithTimeZone, pgModel.TimestampWithoutTimeZone, + + // Text types also require the `->>' operator otherwise type casting clobbers itself + pgModel.Text: + + if _, err := io.WriteString(writer, ".properties->>'"); err != nil { + return nil + } + + default: + if _, err := io.WriteString(writer, ".properties->'"); err != nil { + return nil + } + } + + if _, err := WriteStrings(writer, typedExpression.Symbols[0], "')::", typedExpression.Type.String()); err != nil { + return nil + } + + case *model.PropertyLookup: + if err := s.WriteExpression(writer, typedExpression.Atom); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".properties->'", typedExpression.Symbols[0], "'"); err != nil { + return nil + } + + case *pgModel.AnnotatedKindMatcher: + if err := s.WriteExpression(writer, typedExpression.Reference); err != nil { + return err + } + + if mappedKinds, missingKinds := s.kindMapper.MapKinds(typedExpression.Kinds); len(missingKinds) > 0 { + return fmt.Errorf("query references the following undefined kinds: %v", missingKinds.Strings()) + } else { + mappedKindStr := JoinInt(mappedKinds, ", ") + + switch typedExpression.Type { + case pgModel.Node: + if _, err := WriteStrings(writer, ".kind_ids operator(pg_catalog.&&) array[", mappedKindStr, "]::int2[]"); err != nil { + return err + } + + case pgModel.Edge: + 
if _, err := WriteStrings(writer, ".kind_id = any(array[", mappedKindStr, "]::int2[])"); err != nil { + return err + } + } + } + + case *model.FunctionInvocation: + if err := s.translateFunctionInvocation(writer, typedExpression); err != nil { + return err + } + + case *model.Parameter: + if _, err := WriteStrings(writer, "@", typedExpression.Symbol); err != nil { + return err + } + + case *pgModel.AnnotatedParameter: + if _, err := WriteStrings(writer, "@", typedExpression.Symbol); err != nil { + return err + } + + case *model.Parenthetical: + if _, err := WriteStrings(writer, "("); err != nil { + return err + } + + if err := s.WriteExpression(writer, typedExpression.Expression); err != nil { + return err + } + + if _, err := WriteStrings(writer, ")"); err != nil { + return err + } + + case *pgModel.PropertiesReference: + if err := s.WriteExpression(writer, typedExpression.Reference); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".properties"); err != nil { + return err + } + + case *model.ProjectionItem: + if err := s.WriteExpression(writer, typedExpression.Expression); err != nil { + return err + } + + if _, err := WriteStrings(writer, " as "); err != nil { + return err + } + + if typedExpression.Binding != nil { + if err := s.WriteExpression(writer, typedExpression.Binding); err != nil { + return err + } + } else { + if _, err := WriteStrings(writer, "\""); err != nil { + return err + } + + switch typedProjectionExpression := typedExpression.Expression.(type) { + case *pgModel.NodeKindsReference: + if err := s.WriteExpression(writer, typedProjectionExpression); err != nil { + return err + } + + case *pgModel.EdgeKindReference: + if err := s.WriteExpression(writer, typedProjectionExpression); err != nil { + return err + } + + case *model.FunctionInvocation: + if err := s.WriteExpression(writer, typedProjectionExpression); err != nil { + return err + } + + case *model.PropertyLookup: + if err := s.WriteExpression(writer, 
typedProjectionExpression.Atom); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".", typedProjectionExpression.Symbols[0]); err != nil { + return err + } + + case *pgModel.AnnotatedPropertyLookup: + if err := s.WriteExpression(writer, typedProjectionExpression.Atom); err != nil { + return err + } + + if _, err := WriteStrings(writer, ".", typedProjectionExpression.Symbols[0]); err != nil { + return err + } + + case *pgModel.AnnotatedVariable: + if err := s.WriteExpression(writer, typedProjectionExpression.Symbol); err != nil { + return err + } + + case *pgModel.Entity: + if err := s.WriteExpression(writer, typedProjectionExpression.Binding); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected projection item for binding formatting: %T", typedExpression.Expression) + } + + if _, err := WriteStrings(writer, "\""); err != nil { + return err + } + } + default: + return fmt.Errorf("unexpected expression type for string formatting: %T", expression) + } + + return nil +} + +func (s *Emitter) translateFunctionInvocation(writer io.Writer, functionInvocation *model.FunctionInvocation) error { + switch functionInvocation.Name { + case cypherIdentityFunction: + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, ".id"); err != nil { + return err + } + + case cypherDateFunction: + if len(functionInvocation.Arguments) > 0 { + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, "::date"); err != nil { + return err + } + } else if _, err := io.WriteString(writer, "current_date"); err != nil { + return err + } + + case cypherTimeFunction: + if len(functionInvocation.Arguments) > 0 { + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, "::time with time zone"); err != nil { + 
return err + } + } else if _, err := io.WriteString(writer, "current_time"); err != nil { + return err + } + + case cypherLocalTimeFunction: + if len(functionInvocation.Arguments) > 0 { + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, "::time without time zone"); err != nil { + return err + } + } else if _, err := io.WriteString(writer, "localtime"); err != nil { + return err + } + + case cypherToLowerFunction: + if _, err := WriteStrings(writer, pgsqlToLowerFunction, "("); err != nil { + return err + } + + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := WriteStrings(writer, ")"); err != nil { + return err + } + + case cypherDateTimeFunction: + if len(functionInvocation.Arguments) > 0 { + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, "::timestamp with time zone"); err != nil { + return err + } + } else if _, err := io.WriteString(writer, "now()"); err != nil { + return err + } + + case cypherLocalDateTimeFunction: + if len(functionInvocation.Arguments) > 0 { + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, "::timestamp without time zone"); err != nil { + return err + } + } else if _, err := io.WriteString(writer, "localtimestamp"); err != nil { + return err + } + + case cypherCountFunction: + if _, err := WriteStrings(writer, "count("); err != nil { + return err + } + + for _, argument := range functionInvocation.Arguments { + if err := s.WriteExpression(writer, argument); err != nil { + return err + } + } + + if _, err := WriteStrings(writer, ")"); err != nil { + return err + } + + case pgsqlAnyFunction, pgsqlToJSONBFunction: + if _, err := WriteStrings(writer, functionInvocation.Name, "("); err != nil { + return err + } 
+ + if err := s.WriteExpression(writer, functionInvocation.Arguments[0]); err != nil { + return err + } + + if _, err := io.WriteString(writer, ")"); err != nil { + return err + } + + default: + return fmt.Errorf("unsupported function invocation %s", functionInvocation.Name) + } + + return nil +} + +func (s *Emitter) Write(regularQuery *model.RegularQuery, writer io.Writer) error { + if regularQuery.SingleQuery != nil { + if regularQuery.SingleQuery.MultiPartQuery != nil { + return fmt.Errorf("not supported yet") + } + + if regularQuery.SingleQuery.SinglePartQuery != nil { + if err := s.writeSinglePartQuery(writer, regularQuery.SingleQuery.SinglePartQuery); err != nil { + return err + } + } + } + + return nil +} diff --git a/packages/go/cypher/backend/pgsql/format_test.go b/packages/go/cypher/backend/pgsql/format_test.go new file mode 100644 index 0000000000..e8d972ac45 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/format_test.go @@ -0,0 +1,710 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql_test + +import ( + "bytes" + "fmt" + "github.com/jackc/pgtype" + "github.com/specterops/bloodhound/cypher/backend/pgsql" + "github.com/specterops/bloodhound/cypher/frontend" + "github.com/specterops/bloodhound/cypher/model" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/dawgs/query" + "github.com/specterops/bloodhound/graphschema/ad" + "github.com/specterops/bloodhound/graphschema/common" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func MustMarshalToJSONB(value any) *pgtype.JSONB { + jsonb := &pgtype.JSONB{} + + if err := jsonb.Set(value); err != nil { + panic(fmt.Sprintf("Unable to marshal value type %T to JSONB: %v", value, err)) + } + + return jsonb +} + +type KindMapper struct { + known map[string]int16 +} + +func (s KindMapper) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { + var ( + kindIDs = make([]int16, 0, len(kinds)) + missingKinds = make([]graph.Kind, 0, len(kinds)) + ) + + for _, kind := range kinds { + if kindID, hasKind := s.known[kind.String()]; hasKind { + kindIDs = append(kindIDs, kindID) + } else { + missingKinds = append(missingKinds, kind) + } + } + + return kindIDs, missingKinds +} + +type TestCase struct { + ID int + Source string + Query *query.Builder + Expected string + ExpectedParameters map[string]any + Exclusive bool + Ignored bool + Error bool +} + +func Suite() []TestCase { + return []TestCase{ + { + ID: 1, + Source: "match (s) return s skip 5 limit 10", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s offset 5 limit 10", + }, + { + ID: 2, + Source: "match (s) return s order by s.name, s.other_prop desc", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s order by s.properties->'name' asc, s.properties->'other_prop' desc", + }, + { + ID: 3, + Source: "match (s) where (s)-[]->() return s", + Expected: "select (s.id, s.kind_ids, 
s.properties)::nodeComposite as s from node as s where exists(select * from node as n2 join edge e0 on e0.start_id = n2.id join node n1 on n1.id = e0.end_id where s.id = n2.id limit 1)", + }, + { + ID: 4, + Source: "match ()-[r]->() where (s {name: 'test'})-[r]->() return r", + Expected: "select (r.id, r.start_id, r.end_id, r.kind_id, r.properties)::edgeComposite as r from node as n1 join edge r on r.start_id = n1.id join node n2 on n2.id = r.end_id where exists(select * from node as s join edge e3 on e3.start_id = s.id join node n0 on n0.id = e3.end_id where (s.properties->>'name')::text = 'test' and r.id = e3.id limit 1)", + }, + { + ID: 5, + Source: "match (s {value: 'PII'})-[r {other: 234}]->(e {that: 456}) where s.other = 'more pii' and e.number = 411 return s, r, e", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s, (r.id, r.start_id, r.end_id, r.kind_id, r.properties)::edgeComposite as r, (e.id, e.kind_ids, e.properties)::nodeComposite as e from node as s join edge r on r.start_id = s.id join node e on e.id = r.end_id where (s.properties->>'value')::text = 'PII' and (r.properties->'other')::int8 = 234 and (e.properties->'that')::int8 = 456 and (s.properties->>'other')::text = 'more pii' and (e.properties->'number')::int8 = 411", + }, + { + ID: 6, + Source: "match (s)-[r:EdgeKindA|EdgeKindB]->(e) return s.name, e.name", + Expected: "select s.properties->'name' as \"s.name\", e.properties->'name' as \"e.name\" from node as s join edge r on r.start_id = s.id join node e on e.id = r.end_id where r.kind_id = any(array[100, 101]::int2[])", + }, + { + ID: 7, + Source: "match (s)<-[r:EdgeKindA|EdgeKindB]-(e) return s.name, e.name", + Expected: "select s.properties->'name' as \"s.name\", e.properties->'name' as \"e.name\" from node as s join edge r on r.end_id = s.id join node e on e.id = r.start_id where r.kind_id = any(array[100, 101]::int2[])", + }, + { + ID: 8, + Source: "match (s)-[:EdgeKindA|EdgeKindB]->(e)-[:EdgeKindA]->() return 
s.name, e.name", + Expected: "select s.properties->'name' as \"s.name\", e.properties->'name' as \"e.name\" from node as s join edge e0 on e0.start_id = s.id join node e on e.id = e0.end_id join edge e1 on e1.start_id = e.id join node n2 on n2.id = e1.end_id where e0.kind_id = any(array[100, 101]::int2[]) and e1.kind_id = any(array[100]::int2[])", + }, + { + ID: 9, + Source: "match (s:NodeKindA)-[r:EdgeKindA|EdgeKindB]->(e:NodeKindB) return s.name, e.name", + Expected: "select s.properties->'name' as \"s.name\", e.properties->'name' as \"e.name\" from node as s join edge r on r.start_id = s.id join node e on e.id = r.end_id where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] and r.kind_id = any(array[100, 101]::int2[]) and e.kind_ids operator(pg_catalog.&&) array[2]::int2[]", + }, + { + ID: 10, + Source: "match (s) where s.name = '123' return s.name", + Expected: "select s.properties->'name' as \"s.name\" from node as s where (s.properties->>'name')::text = '123'", + }, + { + ID: 11, + Source: "match (s:NodeKindA), (o:NodeKindB) where s.objectid = '123' and o.linked = s.linkid return o", + Expected: "select (o.id, o.kind_ids, o.properties)::nodeComposite as o from node as s, node as o where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] and o.kind_ids operator(pg_catalog.&&) array[2]::int2[] and (s.properties->>'objectid')::text = '123' and o.properties->'linked' = s.properties->'linkid'", + }, + { + ID: 12, + Source: "match (s) where s.name in ['option 1', 'option 2'] return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'name')::text in array['option 1', 'option 2']", + }, + { + ID: 13, + Source: "match (s) where id(s) in [1, 2, 3, 4] return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where s.id in array[1, 2, 3, 4]", + }, + { + ID: 14, + Source: "match (s) where s.created_at = localtime() return s", + Expected: "select (s.id, s.kind_ids, 
s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::time without time zone = localtime", + }, + { + ID: 15, + Source: "match (s) where s.created_at = localtime('12:12:12') return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::time without time zone = '12:12:12'::time without time zone", + }, + { + ID: 16, + Source: "match (s) where s.created_at = date() return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::date = current_date", + }, + { + ID: 17, + Source: "match (s) where s.created_at = date('2023-12-12') return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::date = '2023-12-12'::date", + }, + { + ID: 18, + Source: "match (s) where s.created_at = datetime() return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::timestamp with time zone = now()", + }, + { + ID: 19, + Source: "match (s) where s.name = '1234' return count(s) as num", + Expected: "select count(s) as num from node as s where (s.properties->>'name')::text = '1234'", + }, + { + ID: 20, + Source: "match (s) where s.created_at = datetime('2019-06-01T18:40:32.142+0100') return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::timestamp with time zone = '2019-06-01T18:40:32.142+0100'::timestamp with time zone", + }, + { + ID: 21, + Source: "match (s) where not (s.name = '123') return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where not ((s.properties->>'name')::text = '123')", + }, + { + ID: 22, + Source: "match (s) where s.created_at = localdatetime() return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite 
as s from node as s where (s.properties->>'created_at')::timestamp without time zone = localtimestamp", + }, + { + ID: 23, + Source: "match (s) where s.created_at = localdatetime('2019-06-01T18:40:32.142') return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'created_at')::timestamp without time zone = '2019-06-01T18:40:32.142'::timestamp without time zone", + }, + { + ID: 24, + Source: "match (s) where s.created_at is null return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where not s.properties ? 'created_at'", + }, + { + ID: 25, + Source: "match (s) where s.created_at is not null return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where s.properties ? 'created_at'", + }, + { + ID: 26, + Source: "match (s) where s:NodeKindA return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where s.kind_ids operator(pg_catalog.&&) array[1]::int2[]", + }, + { + ID: 27, + Source: "match (s) where s.name starts with '123' return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'name')::text like '123%'", + }, + { + ID: 28, + Source: "match (s) where s.name contains '123' return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'name')::text like '%123%'", + }, + { + ID: 29, + Source: "match (s) where s.name ends with '123' return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where (s.properties->>'name')::text like '%123'", + }, + { + ID: 30, + Source: "match (s) where s:NodeKindA return s", + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where s.kind_ids operator(pg_catalog.&&) array[1]::int2[]", + }, + { + ID: 31, + Source: "match (s) where s:NodeKindA return 
distinct s", + Expected: "select distinct (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where s.kind_ids operator(pg_catalog.&&) array[1]::int2[]", + }, + { + ID: 32, + Source: "match (s) where toLower(s.name) = '1234' return distinct s", + Expected: "select distinct (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s where lower((s.properties->>'name')::text) = '1234'", + }, + { + ID: 33, + Source: "match (s) where s.name = '1234' return labels(s)", + Expected: "select s.kind_ids as \"s.kind_ids\" from node as s where (s.properties->>'name')::text = '1234'", + }, + { + ID: 34, + Source: "match ()-[r]->() where r.name = '1234' return type(r)", + Expected: "select r.kind_id as \"r.kind_id\" from node as n0 join edge r on r.start_id = n0.id join node n1 on n1.id = r.end_id where (r.properties->>'name')::text = '1234'", + }, + { + ID: 35, + Ignored: true, + Source: "match p = (s:NodeKindA)-[:EdgeKindA*..]->(:NodeKindB) where id(s) = 5 return p", + Expected: "select edges_to_path() as p from node as s join edge e0 on e0.start_id = s.id join node n1 on n1.id = e0.end_id where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] and e0.kind_id = any(array[100]::int2[]) and n1.kind_ids operator(pg_catalog.&&) array[2]::int2[] and s.id = 5", + }, + { + ID: 36, + Source: "match (s) where s.created_at = localtime() delete s", + Expected: "delete from node as s where (s.properties->>'created_at')::time without time zone = localtime", + }, + { + ID: 37, + Source: "match (s) where s.created_at = localtime() detach delete s", + Expected: "delete from node as s where (s.properties->>'created_at')::time without time zone = localtime", + }, + { + ID: 38, + Source: "match ()-[r]->() where r.name = '1234' delete r", + Expected: "delete from edge as r using node as n0, node as n1 where (r.properties->>'name')::text = '1234' and n0.id = r.start_id and n1.id = r.end_id", + }, + { + ID: 39, + Source: "match (s)-[r]->() where s.name = '1234' delete r", 
+ Expected: "delete from edge as r using node as s, node as n0 where (s.properties->>'name')::text = '1234' and s.id = r.start_id and n0.id = r.end_id", + }, + { + ID: 40, + Source: "match (s)-[r]->(e) where s.name = '1234' delete r", + Expected: "delete from edge as r using node as s, node as e where (s.properties->>'name')::text = '1234' and s.id = r.start_id and e.id = r.end_id", + }, + { + ID: 41, + Source: "match ()-[r]->(e) where e.name = '1234' delete r", + Expected: "delete from edge as r using node as n0, node as e where (e.properties->>'name')::text = '1234' and n0.id = r.start_id and e.id = r.end_id", + }, + { + ID: 42, + Source: "match ()-[r:EdgeKindA]->(e) delete e", + Expected: "delete from node as e using node as n0, edge as r where r.kind_id = any(array[100]::int2[]) and n0.id = r.start_id and e.id = r.end_id", + }, + { + ID: 43, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.InIDs(query.Node(), 1, 2, 3), + )), + query.Returning(query.Node()), + ), + Expected: "select (n.id, n.kind_ids, n.properties)::nodeComposite as n from node as n where n.id = any(@p0)", + }, + { + ID: 44, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.In(query.NodeProperty("prop"), []string{"1", "2", "3"}), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where (n.properties->>'prop')::text = any(@p0)", + }, + { + ID: 45, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.In(query.NodeProperty("prop"), []int16{1, 2, 3}), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where (n.properties->'prop')::int2 = any(@p0)", + }, + { + ID: 46, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.In(query.NodeProperty("prop"), []int32{1, 2, 3}), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where (n.properties->'prop')::int4 = any(@p0)", + }, + { + ID: 
47, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.In(query.NodeProperty("prop"), []int64{1, 2, 3}), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where (n.properties->'prop')::int8 = any(@p0)", + }, + { + ID: 48, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.Kind(query.Relationship(), graph.StringKind("EdgeKindA")), + query.Kind(query.End(), graph.StringKind("NodeKindA")), + query.In(query.EndProperty(common.ObjectID.String()), []string{"12345", "23456"}), + )), + query.Delete(query.Relationship()), + ), + + Expected: "delete from edge as r using node as n0, node as e where r.kind_id = any(array[100]::int2[]) and e.kind_ids operator(pg_catalog.&&) array[1]::int2[] and (e.properties->>'objectid')::text = any(@p0) and n0.id = r.start_id and e.id = r.end_id", + }, + { + ID: 49, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.Kind(query.Node(), graph.StringKind("NodeKindA")), + query.StringContains(query.NodeProperty(common.OperatingSystem.String()), "WINDOWS"), + query.Exists(query.NodeProperty(common.PasswordLastSet.String())), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where n.kind_ids operator(pg_catalog.&&) array[1]::int2[] and (n.properties->>'operatingsystem')::text like @p0 and n.properties ? 
'pwdlastset'", + }, + { + ID: 50, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.KindIn(query.Node(), graph.StringKind("NodeKindA"), graph.StringKind("NodeKindB")), + query.StringEndsWith(query.NodeProperty(common.ObjectID.String()), "-5-1-9"), + query.Equals(query.NodeProperty(ad.DomainSID.String()), "DOMAINSID"), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where (n.kind_ids operator(pg_catalog.&&) array[1, 2]::int2[]) and (n.properties->>'objectid')::text like @p0 and (n.properties->>'domainsid')::text = @p1", + ExpectedParameters: map[string]any{ + "p0": "%-5-1-9", + "p1": "DOMAINSID", + }, + }, + { + ID: 51, + Query: query.NewBuilderWithCriteria( + query.Where( + query.KindIn(query.Relationship(), graph.StringKind("EdgeKindA"), graph.StringKind("EdgeKindB")), + ), + query.Returning(query.Start()), + ), + Expected: "select (s.id, s.kind_ids, s.properties)::nodeComposite as s from node as s join edge r on r.start_id = s.id join node n0 on n0.id = r.end_id where (r.kind_id = any(array[100, 101]::int2[]))", + }, + { + ID: 52, + Query: query.NewBuilderWithCriteria( + query.Where( + query.Not(query.HasRelationships(query.Node())), + ), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where not exists(select * from node as n2 join edge e0 on e0.start_id = n2.id or e0.end_id = n2.id join node n1 on n1.id = e0.start_id or n1.id = e0.end_id where n.id = n2.id limit 1)", + }, + { + ID: 53, + Query: query.NewBuilderWithCriteria( + query.Where(query.And( + query.In(query.NodeProperty("prop"), []float32{1, 2, 3}), + )), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where (n.properties->'prop')::float4 = any(@p0)", + }, + { + ID: 54, + Query: query.NewBuilderWithCriteria( + query.Where( + query.And( + query.In(query.NodeProperty("prop"), []float64{1, 2, 3}), + ), + ), + query.Returning(query.NodeID()), + ), + 
Expected: "select n.id as \"n.id\" from node as n where (n.properties->'prop')::float8 = any(@p0)", + }, + { + ID: 55, + Query: query.NewBuilderWithCriteria( + query.Where( + query.And( + query.KindIn(query.Relationship(), graph.StringKind("EdgeKindA"), graph.StringKind("EdgeKindB")), + query.Or( + query.Not(query.Exists(query.RelationshipProperty(common.LastSeen.String()))), + query.Before(query.RelationshipProperty(common.LastSeen.String()), time.Date(2023, time.August, 01, 0, 0, 0, 0, time.Local)), + ), + ), + ), + query.Returning(query.Relationship()), + ), + Expected: "select (r.id, r.start_id, r.end_id, r.kind_id, r.properties)::edgeComposite as r from node as n0 join edge r on r.start_id = n0.id join node n1 on n1.id = r.end_id where (r.kind_id = any(array[100, 101]::int2[])) and (not r.properties ? 'lastseen' or (r.properties->>'lastseen')::timestamp with time zone < @p0)", + }, + { + ID: 56, + Query: query.NewBuilderWithCriteria( + query.Where( + query.And( + query.Kind(query.Node(), graph.StringKind("NodeKindA")), + query.Or( + query.Equals(query.NodeProperty("name"), "12345"), + query.Equals(query.NodeProperty("objectid"), "12345"), + ), + query.Not( + query.And( + query.Kind(query.Node(), graph.StringKind("NodeKindB")), + query.Not(query.Kind(query.Node(), graph.StringKind("NodeKindC"))), + ), + ), + ), + ), + query.Delete(query.Node()), + ), + Expected: "delete from node as n where n.kind_ids operator(pg_catalog.&&) array[1]::int2[] and ((n.properties->>'name')::text = @p0 or (n.properties->>'objectid')::text = @p1) and not (n.kind_ids operator(pg_catalog.&&) array[2]::int2[] and not n.kind_ids operator(pg_catalog.&&) array[3]::int2[])", + }, + { + ID: 57, + Query: query.NewBuilderWithCriteria( + query.Where( + query.And( + query.Kind(query.Node(), graph.StringKind("NodeKindA")), + query.Or( + query.StringContains(query.NodeProperty("name"), "name"), + query.StringContains(query.NodeProperty("objectid"), "name"), + ), + 
query.Not(query.Equals(query.NodeProperty("name"), "name")), + query.Not(query.Equals(query.NodeProperty("objectid"), "name")), + query.Not( + query.And( + query.Kind(query.Node(), graph.StringKind("NodeKindB")), + query.Not(query.Kind(query.Node(), graph.StringKind("NodeKindC"))), + ), + ), + ), + ), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where n.kind_ids operator(pg_catalog.&&) array[1]::int2[] and ((n.properties->>'name')::text like @p0 or (n.properties->>'objectid')::text like @p1) and not (n.properties->>'name')::text = @p2 and not (n.properties->>'objectid')::text = @p3 and not (n.kind_ids operator(pg_catalog.&&) array[2]::int2[] and not n.kind_ids operator(pg_catalog.&&) array[3]::int2[])", + ExpectedParameters: map[string]any{ + "p0": "%name%", + "p1": "%name%", + "p2": "name", + "p3": "name", + }, + }, + { + ID: 57, + Query: query.NewBuilderWithCriteria( + query.Where( + query.And( + query.Kind(query.Node(), graph.StringKind("NodeKindA")), + query.Equals(query.NodeProperty(common.ObjectID.String()), "67CE0FEC-166C-4E5E-BF87-6FBAF0E9C8A8"), + query.Equals(query.NodeProperty(common.Name.String()), "CLIENTAUTH@ESC1.LOCAL"), + query.Equals(query.NodeProperty(ad.DomainSID.String()), "S-1-5-21-909015691-3030120388-2582151266"), + query.Equals(query.NodeProperty(ad.DistinguishedName.String()), "CN=CLIENTAUTH,CN=CERTIFICATE TEMPLATES,CN=PUBLIC KEY SERVICES,CN=SERVICES,CN=CONFIGURATION,DC=ESC1,DC=LOCAL"), + query.Equals(query.NodeProperty(ad.ValidityPeriod.String()), "1 year"), + query.Equals(query.NodeProperty(ad.RenewalPeriod.String()), "6 weeks"), + query.Equals(query.NodeProperty(ad.SchemaVersion.String()), 1), + query.Equals(query.NodeProperty(ad.OID.String()), "1.3.6.1.4.1.311.21.8.12059088.7148202.5130407.12905872.6174753.77.1.4"), + query.Equals(query.NodeProperty(ad.EnrollmentFlag.String()), "AUTO_ENROLLMENT"), + query.Equals(query.NodeProperty(ad.RequiresManagerApproval.String()), false), + 
query.Equals(query.NodeProperty(ad.NoSecurityExtension.String()), false), + query.Equals(query.NodeProperty(ad.CertificateNameFlag.String()), "SUBJECT_ALT_REQUIRE_UPN, SUBJECT_REQUIRE_DIRECTORY_PATH"), + query.Equals(query.NodeProperty(ad.EnrolleeSuppliesSubject.String()), false), + query.Equals(query.NodeProperty(ad.SubjectAltRequireUPN.String()), true), + query.Equals(query.NodeProperty(ad.EKUs.String()), []string{"1.3.6.1.5.5.7.3.2"}), + query.Equals(query.NodeProperty(ad.CertificateApplicationPolicy.String()), []string{}), + query.Equals(query.NodeProperty(ad.AuthorizedSignatures.String()), 0), + query.Equals(query.NodeProperty(ad.ApplicationPolicies.String()), []string{}), + query.Equals(query.NodeProperty(ad.IssuancePolicies.String()), []string{}), + query.Equals(query.NodeProperty(ad.EffectiveEKUs.String()), []string{"1.3.6.1.5.5.7.3.2"}), + query.Equals(query.NodeProperty(ad.AuthenticationEnabled.String()), true)), + ), + query.Returning(query.NodeID()), + ), + Expected: "select n.id as \"n.id\" from node as n where n.kind_ids operator(pg_catalog.&&) array[1]::int2[] and (n.properties->>'objectid')::text = @p0 and (n.properties->>'name')::text = @p1 and (n.properties->>'domainsid')::text = @p2 and (n.properties->>'distinguishedname')::text = @p3 and (n.properties->>'validityperiod')::text = @p4 and (n.properties->>'renewalperiod')::text = @p5 and (n.properties->'schemaversion')::int8 = @p6 and (n.properties->>'oid')::text = @p7 and (n.properties->>'enrollmentflag')::text = @p8 and (n.properties->'requiresmanagerapproval')::bool = @p9 and (n.properties->'nosecurityextension')::bool = @p10 and (n.properties->>'certificatenameflag')::text = @p11 and (n.properties->'enrolleesuppliessubject')::bool = @p12 and (n.properties->'subjectaltrequireupn')::bool = @p13 and (n.properties->'ekus')::jsonb = @p14 and (n.properties->'certificateapplicationpolicy')::jsonb = @p15 and (n.properties->'authorizedsignatures')::int8 = @p16 and 
(n.properties->'applicationpolicies')::jsonb = @p17 and (n.properties->'issuancepolicies')::jsonb = @p18 and (n.properties->'effectiveekus')::jsonb = @p19 and (n.properties->'authenticationenabled')::bool = @p20", + ExpectedParameters: map[string]interface{}{ + "p0": "67CE0FEC-166C-4E5E-BF87-6FBAF0E9C8A8", + "p1": "CLIENTAUTH@ESC1.LOCAL", + "p10": false, + "p11": "SUBJECT_ALT_REQUIRE_UPN, SUBJECT_REQUIRE_DIRECTORY_PATH", + "p12": false, + "p13": true, + "p14": MustMarshalToJSONB([]string{"1.3.6.1.5.5.7.3.2"}), + "p15": MustMarshalToJSONB([]string{}), + "p16": 0, + "p17": MustMarshalToJSONB([]string{}), + "p18": MustMarshalToJSONB([]string{}), + "p19": MustMarshalToJSONB([]string{"1.3.6.1.5.5.7.3.2"}), + "p2": "S-1-5-21-909015691-3030120388-2582151266", + "p20": true, + "p3": "CN=CLIENTAUTH,CN=CERTIFICATE TEMPLATES,CN=PUBLIC KEY SERVICES,CN=SERVICES,CN=CONFIGURATION,DC=ESC1,DC=LOCAL", + "p4": "1 year", + "p5": "6 weeks", + "p6": 1, + "p7": "1.3.6.1.4.1.311.21.8.12059088.7148202.5130407.12905872.6174753.77.1.4", + "p8": "AUTO_ENROLLMENT", + "p9": false}, + }, + + // UPDATE CASES + { + ID: 158, + Source: "match (s) where s:NodeKindA set s:NodeKindB return s", + Expected: "update node as s set kind_ids = kind_ids || @p0 where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] returning (s.id, s.kind_ids, s.properties)::nodeComposite as s", + }, + { + ID: 159, + Source: "match (s) where s:NodeKindA set s:NodeKindB remove s:NodeKindA return s", + Expected: "update node as s set kind_ids = kind_ids - @p1 || @p0 where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] returning (s.id, s.kind_ids, s.properties)::nodeComposite as s", + }, + { + ID: 160, + Source: "match (s) set s.name = 'new name', s:NodeKindA return s", + Expected: "update node as s set properties = properties || @p0, kind_ids = kind_ids || @p1 returning (s.id, s.kind_ids, s.properties)::nodeComposite as s", + }, + { + ID: 161, + Source: "match (s) where s:NodeKindA set s.name = 'new name' return s", + 
Expected: "update node as s set properties = properties || @p0 where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] returning (s.id, s.kind_ids, s.properties)::nodeComposite as s", + }, + { + ID: 162, + Source: "match (s) where s:NodeKindA set s.name = 'lol' remove s.other return s", + Expected: "update node as s set properties = properties - @p1::text[] || @p0 where s.kind_ids operator(pg_catalog.&&) array[1]::int2[] returning (s.id, s.kind_ids, s.properties)::nodeComposite as s", + }, + + // TODO: This is commented because all shortest paths is not directly supported by the cypher-to-pg translation + // but should be. Future effort should enable this test case as native pathfinding in the pg database + // is now formally supported. + //{ + // ID: 63, + // Source: "match p = allShortestPaths((:NodeKindA)-[:EdgeKindA*..]->(:NodeKindB)) return p", + // Expected: "", + // Exclusive: true, + //}, + + // ERROR CASES + + // Mixed types in a list match should fail. Once a field type is set there must be no ambiguity. + { + ID: 200, + Source: "match (s) where s.name in ['option 1', 'option 2', 1234] return s", + Error: true, + }, + + // UNSUPPORTED CASES + + // The following queries are going to require running each match as a distinct select statements with a left + // outer join to combine result sets. This is pretty ill-defined and a stupid feature if you ask me, so I'm + // going to leave it out for now. 
+ { + ID: 300, + Source: "match (s), (e)-[]->(o) where s.name = '123' and e.name = 'lol' return s.name, e, o", + Ignored: true, + }, + { + ID: 301, + Source: "match (s) where s.name = '123' match (e) where e.name = 'lol' return s.name, e", + Ignored: true, + }, + } +} + +func TestPGSQLEmitter(t *testing.T) { + var ( + runnable []TestCase + exclusiveRun bool + ) + + for _, testCase := range Suite() { + if testCase.Ignored { + continue + } + + if testCase.Exclusive { + if !exclusiveRun { + runnable = runnable[:0] + exclusiveRun = true + } + + runnable = append(runnable, testCase) + } else if !exclusiveRun { + runnable = append(runnable, testCase) + } + } + + for _, testCase := range runnable { + var regularQuery *model.RegularQuery + + if testCase.Query != nil { + builtQuery, err := testCase.Query.Build() + require.Nilf(t, err, "test case %d: %v", testCase.ID, err) + + regularQuery = builtQuery + + } else { + parsedQuery, parseErr := frontend.ParseCypher(frontend.NewContext(), testCase.Source) + require.Nilf(t, parseErr, "test case %d: %v", testCase.ID, parseErr) + + regularQuery = parsedQuery + } + + var ( + buffer = &bytes.Buffer{} + kindMapper = KindMapper{ + known: map[string]int16{ + "NodeKindA": 1, + "NodeKindB": 2, + "NodeKindC": 3, + "EdgeKindA": 100, + "EdgeKindB": 101, + "EdgeKindC": 102, + }, + } + + parameters, translationErr = pgsql.Translate(regularQuery, kindMapper) + ) + + if testCase.Error { + if translationErr != nil { + continue + } + + var ( + emitter = pgsql.NewEmitter(false, kindMapper) + emitterErr = emitter.Write(regularQuery, buffer) + ) + + require.NotNilf(t, emitterErr, "test case %d: %v", testCase.ID, emitterErr) + } else { + require.Nilf(t, translationErr, "test case %d: %v", testCase.ID, translationErr) + + if testCase.ExpectedParameters != nil { + require.Equal(t, testCase.ExpectedParameters, parameters) + } + + var ( + emitter = pgsql.NewEmitter(false, kindMapper) + emitterErr = emitter.Write(regularQuery, buffer) + ) + + 
require.Nilf(t, emitterErr, "test case %d: %v", testCase.ID, emitterErr) + require.Equalf(t, testCase.Expected, buffer.String(), "test case %d", testCase.ID) + } + } +} + +func TestBinder(t *testing.T) { + var ( + binder = pgsql.NewBinder() + regularQuery, parseErr = frontend.ParseCypher(frontend.DefaultCypherContext(), "match (s) with s as m return s") + binderErr = binder.Scan(regularQuery) + ) + + require.Nil(t, parseErr) + require.Nil(t, binderErr) + + require.True(t, binder.IsBound("s")) + require.True(t, binder.IsPatternBinding("s")) + require.True(t, binder.IsBound("m")) + + // TODO: This might want to be true depending on how references play out during joins + require.False(t, binder.IsPatternBinding("m")) +} diff --git a/packages/go/cypher/backend/pgsql/model.go b/packages/go/cypher/backend/pgsql/model.go new file mode 100644 index 0000000000..3327c4721b --- /dev/null +++ b/packages/go/cypher/backend/pgsql/model.go @@ -0,0 +1,424 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +import ( + "fmt" + cypherModel "github.com/specterops/bloodhound/cypher/model" + "github.com/specterops/bloodhound/cypher/model/pg" + + "github.com/specterops/bloodhound/dawgs/graph" +) + +const ( + OperatorJSONBFieldExists cypherModel.Operator = "?" 
+ OperatorLike cypherModel.Operator = "like" + OperatorLikeCaseInsensitive cypherModel.Operator = "ilike" +) + +type UpdatingClauseRewriter struct { + kindMapper KindMapper + binder *Binder + deletion *pg.Delete + propertyReferenceSymbols map[string]struct{} + propertyAdditions map[string]map[string]any + propertyRemovals map[string][]string + kindReferenceSymbols map[string]struct{} + kindRemovals map[string][]graph.Kind + kindAdditions map[string][]graph.Kind +} + +func NewUpdateClauseRewriter(binder *Binder, kindMapper KindMapper) *UpdatingClauseRewriter { + return &UpdatingClauseRewriter{ + kindMapper: kindMapper, + binder: binder, + deletion: pg.NewDelete(), + propertyReferenceSymbols: map[string]struct{}{}, + propertyAdditions: map[string]map[string]any{}, + propertyRemovals: map[string][]string{}, + kindReferenceSymbols: map[string]struct{}{}, + kindRemovals: map[string][]graph.Kind{}, + kindAdditions: map[string][]graph.Kind{}, + } +} + +func (s *UpdatingClauseRewriter) newPropertyMutation(symbol string) (*pg.PropertyMutation, error) { + if annotatedVariable, isBound := s.binder.LookupVariable(symbol); !isBound { + return nil, fmt.Errorf("mutation variable reference %s is not bound", symbol) + } else { + return &pg.PropertyMutation{ + Reference: &pg.PropertiesReference{ + Reference: annotatedVariable, + }, + }, nil + } +} + +func (s *UpdatingClauseRewriter) newKindMutation(symbol string) (*pg.KindMutation, error) { + if annotatedVariable, isBound := s.binder.LookupVariable(symbol); !isBound { + return nil, fmt.Errorf("mutation variable reference %s is not bound", symbol) + } else { + return &pg.KindMutation{ + Variable: annotatedVariable, + }, nil + } +} + +func (s *UpdatingClauseRewriter) ToUpdatingClause() ([]cypherModel.Expression, error) { + var updatingClauses []cypherModel.Expression + + if s.deletion.NodeDelete || s.deletion.EdgeDelete { + updatingClauses = append(updatingClauses, s.deletion) + } + + for referenceSymbol := range 
s.propertyReferenceSymbols { + propertyMutation, err := s.newPropertyMutation(referenceSymbol) + + if err != nil { + return nil, err + } + + if propertyAdditions, hasPropertyAdditions := s.propertyAdditions[referenceSymbol]; hasPropertyAdditions { + if propertyAdditionsJSONB, err := MapStringAnyToJSONB(propertyAdditions); err != nil { + return nil, err + } else if newParameter, err := s.binder.NewParameter(propertyAdditionsJSONB); err != nil { + return nil, err + } else { + propertyMutation.Additions = newParameter + } + } + + if propertyRemovals, hasPropertyRemovals := s.propertyRemovals[referenceSymbol]; hasPropertyRemovals { + if propertyRemovalsTextArray, err := StringSliceToTextArray(propertyRemovals); err != nil { + return nil, err + } else if newParameter, err := s.binder.NewParameter(propertyRemovalsTextArray); err != nil { + return nil, err + } else { + propertyMutation.Removals = newParameter + } + } + + updatingClauses = append(updatingClauses, propertyMutation) + } + + for referenceSymbol := range s.kindReferenceSymbols { + kindMutation, err := s.newKindMutation(referenceSymbol) + + if err != nil { + return nil, err + } + + if kindAdditions, hasKindAdditions := s.kindAdditions[referenceSymbol]; hasKindAdditions { + if kindInt2Array, missingKinds := s.kindMapper.MapKinds(kindAdditions); len(missingKinds) > 0 { + return nil, fmt.Errorf("updating clause references the following unknown kinds: %v", missingKinds.Strings()) + } else if newParameter, err := s.binder.NewParameter(kindInt2Array); err != nil { + return nil, err + } else { + kindMutation.Additions = newParameter + } + } + + if kindRemovals, hasKindRemovals := s.kindRemovals[referenceSymbol]; hasKindRemovals { + if kindInt2Array, missingKinds := s.kindMapper.MapKinds(kindRemovals); len(missingKinds) > 0 { + return nil, fmt.Errorf("updating clause references the following unknown kinds: %v", missingKinds.Strings()) + } else if newParameter, err := s.binder.NewParameter(kindInt2Array); err != nil { + 
return nil, err + } else { + kindMutation.Removals = newParameter + } + } + + updatingClauses = append(updatingClauses, kindMutation) + } + + return updatingClauses, nil +} + +func (s *UpdatingClauseRewriter) rewriteDeleteClause(singlePartQuery *cypherModel.SinglePartQuery, deleteClause *cypherModel.Delete) error { + for _, deleteStatementExpression := range deleteClause.Expressions { + switch typedExpression := deleteStatementExpression.(type) { + case *pg.AnnotatedVariable: + switch typedExpression.Type { + case pg.Node: + if s.deletion.NodeDelete { + return fmt.Errorf("multiple node delete statements are not supported") + } + + s.deletion.Binding = typedExpression + s.deletion.NodeDelete = true + + case pg.Edge: + if s.deletion.EdgeDelete { + return fmt.Errorf("multiple edge delete statements are not supported") + } + + s.deletion.Binding = typedExpression + s.deletion.EdgeDelete = true + + default: + return fmt.Errorf("unexpected variable type: %s", typedExpression.Type.String()) + } + + default: + return fmt.Errorf("unexpected expression for delete: %T", deleteStatementExpression) + } + } + + if s.deletion.IsMixed() { + return fmt.Errorf("mixed deletions are not supported") + } + + for _, readingClause := range singlePartQuery.ReadingClauses { + if matchClause := readingClause.Match; matchClause != nil { + var additionalWhereClauses []cypherModel.Expression + + for _, pattern := range matchClause.Pattern { + if len(pattern.PatternElements) <= 1 { + // This pattern does not have a relationship and therefore no joining criteria is required + continue + } + + for idx, patternElement := range pattern.PatternElements { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + var ( + lastNode = idx+1 >= len(pattern.PatternElements) + relBinding *pg.AnnotatedVariable + direction graph.Direction + ) + + if !lastNode { + // Look forward to the next relationship pattern + relPattern, _ := pattern.PatternElements[idx+1].AsRelationshipPattern() 
+ direction = relPattern.Direction + + switch typedBinding := relPattern.Binding.(type) { + case *pg.AnnotatedVariable: + relBinding = typedBinding + default: + return fmt.Errorf("unexpected variable for relationship pattern binding: %T", relPattern.Binding) + } + } else { + // Look backward to the last relationship pattern + relPattern, _ := pattern.PatternElements[idx-1].AsRelationshipPattern() + direction, _ = relPattern.Direction.Reverse() + + switch typedBinding := relPattern.Binding.(type) { + case *pg.AnnotatedVariable: + relBinding = typedBinding + default: + return fmt.Errorf("unexpected variable for relationship pattern binding: %T", relPattern.Binding) + } + } + + switch direction { + case graph.DirectionInbound: + bindingCopy := pg.Copy(relBinding) + bindingCopy.Symbol += ".end_id" + + additionalWhereClauses = append(additionalWhereClauses, cypherModel.NewComparison( + cypherModel.NewSimpleFunctionInvocation(cypherIdentityFunction, nodePattern.Binding), + cypherModel.OperatorEquals, + bindingCopy, + )) + + case graph.DirectionOutbound: + bindingCopy := pg.Copy(relBinding) + bindingCopy.Symbol += ".start_id" + + additionalWhereClauses = append(additionalWhereClauses, cypherModel.NewComparison( + cypherModel.NewSimpleFunctionInvocation(cypherIdentityFunction, nodePattern.Binding), + cypherModel.OperatorEquals, + bindingCopy, + )) + + default: + return fmt.Errorf("invalid pattern direction: %d", direction) + } + } + } + } + + if len(additionalWhereClauses) > 0 { + additionalWhereClause := cypherModel.NewConjunction(additionalWhereClauses...) 
+ + if matchClause.Where == nil { + matchClause.Where = cypherModel.NewWhere() + } + + if len(matchClause.Where.Expressions) > 0 { + matchClause.Where.Expressions = []cypherModel.Expression{ + cypherModel.NewConjunction(append(matchClause.Where.Expressions, additionalWhereClause)...), + } + } else { + matchClause.Where.Add(additionalWhereClause) + } + } + } + } + + return nil +} + +func (s *UpdatingClauseRewriter) RewriteUpdatingClauses(singlePartQuery *cypherModel.SinglePartQuery) error { + for _, updatingClause := range singlePartQuery.UpdatingClauses { + typedUpdatingClause, isUpdatingClause := updatingClause.(*cypherModel.UpdatingClause) + + if !isUpdatingClause { + return fmt.Errorf("unexpected type for updating clause: %T", updatingClause) + } + + switch typedClause := typedUpdatingClause.Clause.(type) { + case *cypherModel.Create: + return fmt.Errorf("create unsupported") + + case *cypherModel.Delete: + if err := s.rewriteDeleteClause(singlePartQuery, typedClause); err != nil { + return err + } + + case *cypherModel.Set: + for _, setItem := range typedClause.Items { + switch leftHandOperand := setItem.Left.(type) { + case *cypherModel.Variable: + switch rightHandOperand := setItem.Right.(type) { + case graph.Kinds: + s.TrackKindAddition(leftHandOperand.Symbol, rightHandOperand...) 
+ + default: + return fmt.Errorf("unexpected right side operand type %T for kind setter", setItem.Right) + } + + case *cypherModel.PropertyLookup: + switch setItem.Operator { + case cypherModel.OperatorAssignment: + var ( + // TODO: Type negotiation + referenceSymbol = leftHandOperand.Atom.(*cypherModel.Variable).Symbol + propertyName = leftHandOperand.Symbols[0] + ) + + switch rightHandOperand := setItem.Right.(type) { + case *cypherModel.Literal: + // TODO: Negotiate null literals + s.TrackPropertyAddition(referenceSymbol, propertyName, rightHandOperand.Value) + + case *pg.AnnotatedLiteral: + s.TrackPropertyAddition(referenceSymbol, propertyName, rightHandOperand.Value) + + case *cypherModel.Parameter: + s.TrackPropertyAddition(referenceSymbol, propertyName, rightHandOperand.Value) + + case *pg.AnnotatedParameter: + s.TrackPropertyAddition(referenceSymbol, propertyName, rightHandOperand.Value) + + default: + return fmt.Errorf("unexpected right side operand type %T for property setter", setItem.Right) + } + + default: + return fmt.Errorf("unsupported assignment operator: %s", setItem.Operator) + } + } + } + + case *cypherModel.Remove: + for _, removeItem := range typedClause.Items { + if removeItem.KindMatcher != nil { + if kindMatcher, typeOK := removeItem.KindMatcher.(*cypherModel.KindMatcher); !typeOK { + return fmt.Errorf("unexpected remove item kind matcher expression: %T", removeItem.KindMatcher) + } else if kindMatcherReference, typeOK := kindMatcher.Reference.(*cypherModel.Variable); !typeOK { + return fmt.Errorf("unexpected remove matcher reference expression: %T", kindMatcher.Reference) + } else { + s.TrackKindRemoval(kindMatcherReference.Symbol, kindMatcher.Kinds...) 
+ } + } + + if removeItem.Property != nil { + var ( + // TODO: Type negotiation + referenceSymbol = removeItem.Property.Atom.(*cypherModel.Variable).Symbol + propertyName = removeItem.Property.Symbols[0] + ) + + s.TrackPropertyRemoval(referenceSymbol, propertyName) + } + } + } + } + + if updatingClauses, err := s.ToUpdatingClause(); err != nil { + return err + } else { + singlePartQuery.UpdatingClauses = updatingClauses + } + + return nil +} + +func (s *UpdatingClauseRewriter) HasAdditions() bool { + return len(s.propertyAdditions) > 0 || len(s.kindAdditions) > 0 +} + +func (s *UpdatingClauseRewriter) HasRemovals() bool { + return len(s.propertyRemovals) > 0 || len(s.kindRemovals) > 0 +} + +func (s *UpdatingClauseRewriter) HasChanges() bool { + return s.HasAdditions() || s.HasRemovals() +} + +func (s *UpdatingClauseRewriter) TrackKindAddition(referenceSymbol string, kinds ...graph.Kind) { + s.kindReferenceSymbols[referenceSymbol] = struct{}{} + + if existingAdditions, hasAdditions := s.kindAdditions[referenceSymbol]; hasAdditions { + s.kindAdditions[referenceSymbol] = append(existingAdditions, kinds...) + } else { + s.kindAdditions[referenceSymbol] = kinds + } +} + +func (s *UpdatingClauseRewriter) TrackKindRemoval(referenceSymbol string, kinds ...graph.Kind) { + s.kindReferenceSymbols[referenceSymbol] = struct{}{} + + if existingRemovals, hasRemovals := s.kindRemovals[referenceSymbol]; hasRemovals { + s.kindRemovals[referenceSymbol] = append(existingRemovals, kinds...) 
+ } else { + s.kindRemovals[referenceSymbol] = kinds + } +} + +func (s *UpdatingClauseRewriter) TrackPropertyAddition(referenceSymbol, propertyName string, value any) { + s.propertyReferenceSymbols[referenceSymbol] = struct{}{} + + if existingAdditions, hasAdditions := s.propertyAdditions[referenceSymbol]; hasAdditions { + existingAdditions[propertyName] = value + } else { + s.propertyAdditions[referenceSymbol] = map[string]any{ + propertyName: value, + } + } +} + +func (s *UpdatingClauseRewriter) TrackPropertyRemoval(referenceSymbol, propertyName string) { + s.propertyReferenceSymbols[referenceSymbol] = struct{}{} + + if existingRemovals, hasRemovals := s.propertyRemovals[referenceSymbol]; hasRemovals { + s.propertyRemovals[referenceSymbol] = append(existingRemovals, propertyName) + } else { + s.propertyRemovals[referenceSymbol] = []string{propertyName} + } +} diff --git a/packages/go/cypher/backend/pgsql/pgtransition/shortestpaths.go b/packages/go/cypher/backend/pgsql/pgtransition/shortestpaths.go new file mode 100644 index 0000000000..66722f3d4f --- /dev/null +++ b/packages/go/cypher/backend/pgsql/pgtransition/shortestpaths.go @@ -0,0 +1,279 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pgtransition + +import ( + "bytes" + "fmt" + "github.com/specterops/bloodhound/cypher/analyzer" + "github.com/specterops/bloodhound/cypher/backend/pgsql" + "github.com/specterops/bloodhound/cypher/model" + "github.com/specterops/bloodhound/cypher/model/pg" + "github.com/specterops/bloodhound/dawgs/query" +) + +type AllShortestPathsArguments struct { + RootCriteria string + TraversalCriteria string + TerminalCriteria string + MaxDepth int +} + +func RewriteParameters(regularQuery *model.RegularQuery) error { + return analyzer.Analyze(regularQuery, func(analyzerInst *analyzer.Analyzer) { + analyzer.WithVisitor(analyzerInst, func(stack *model.WalkStack, node *model.Parameter) error { + parameterValue := node.Value + + switch typedParameterValue := parameterValue.(type) { + case string: + // The cypher AST model expects strings to be contained within a single quote wrapper + parameterValue = "'" + typedParameterValue + "'" + } + + switch typedTrunk := stack.Trunk().(type) { + case model.ExpressionList: + typedTrunk.Replace(typedTrunk.IndexOf(node), query.Literal(parameterValue)) + + case *model.PartialComparison: + typedTrunk.Right = query.Literal(parameterValue) + } + + return nil + }) + }) +} + +func RemoveEmptyExpressionLists(stack *model.WalkStack, element model.Expression) error { + var ( + shouldRemove = false + shouldReplace = false + + replacementExpression model.Expression + ) + + switch typedElement := element.(type) { + case model.ExpressionList: + shouldRemove = typedElement.Len() == 0 + + case *model.Parenthetical: + switch typedParentheticalElement := typedElement.Expression.(type) { + case model.ExpressionList: + numExpressions := typedParentheticalElement.Len() + + shouldRemove = numExpressions == 0 + shouldReplace = numExpressions == 1 + + if shouldReplace { + // Dump the parenthetical and the joined expression by grabbing the only element in the joined + // expression for replacement + 
replacementExpression = typedParentheticalElement.Get(0) + } + } + } + + if shouldRemove { + switch typedParent := stack.Trunk().(type) { + case model.ExpressionList: + typedParent.Remove(element) + } + } else if shouldReplace { + switch typedParent := stack.Trunk().(type) { + case model.ExpressionList: + typedParent.Replace(typedParent.IndexOf(element), replacementExpression) + } + } + + return nil +} + +type Ripper struct { + targetVariableSymbol string +} + +func (s *Ripper) Enter(stack *model.WalkStack, expression model.Expression) error { + if expressionList, isExpressionList := stack.Trunk().(model.ExpressionList); isExpressionList { + switch typedExpression := expression.(type) { + case *model.KindMatcher: + // Look for constraints + if variable, typeOK := typedExpression.Reference.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable in all shortests paths kind matcher but saw: %T", typedExpression.Reference) + } else if variable.Symbol != s.targetVariableSymbol { + // Rip this expression since it's a comparison that targets a variable we don't care about + expressionList.Remove(expression) + } else { + switch s.targetVariableSymbol { + case query.EdgeStartSymbol, query.EdgeEndSymbol: + expressionList.Replace(expressionList.IndexOf(expression), pg.NewAnnotatedKindMatcher(typedExpression, pg.Node)) + + case query.EdgeSymbol: + expressionList.Replace(expressionList.IndexOf(expression), pg.NewAnnotatedKindMatcher(typedExpression, pg.Edge)) + + default: + return fmt.Errorf("unsupported variable symbol: %s", s.targetVariableSymbol) + } + } + + case *model.Comparison: + var leftHandNode = typedExpression.Left + + // Unwrap function invocations that may wrap the left hand expression + switch typedNode := leftHandNode.(type) { + case *model.Variable: + case *model.PropertyLookup: + leftHandNode = typedNode.Atom + + case *model.FunctionInvocation: + if typedNode.Name == model.IdentityFunction { + // Validate the length of the arguments for sanity 
checking + if len(typedNode.Arguments) != 1 { + return fmt.Errorf("expected only 1 argument") + } + + // If this is an ID lookup of the variable pull the variable reference out of it + leftHandNode = typedNode.Arguments[0] + } + + default: + return fmt.Errorf("unexpected left hand comparison expression: %T", leftHandNode) + } + + // Look for constraints + if variable, typeOK := leftHandNode.(*model.Variable); !typeOK { + return fmt.Errorf("expected *pgsql.AnnotatedVariable in all shortests paths comparison but saw: %T", leftHandNode) + } else if variable.Symbol != s.targetVariableSymbol { + // Rip this expression since it's a comparison that targets a variable we don't care about + expressionList.Remove(expression) + } + } + } + + return nil +} + +func (s *Ripper) Exit(stack *model.WalkStack, expression model.Expression) error { + return nil +} + +func TranslateAllShortestPaths(regularQuery *model.RegularQuery, kindMapper pgsql.KindMapper) (AllShortestPathsArguments, error) { + aspArguments := AllShortestPathsArguments{ + MaxDepth: 12, + } + + if regularQuery.SingleQuery.MultiPartQuery != nil { + return aspArguments, fmt.Errorf("multi-part queries not supported") + } + + if numReadingClauses := len(regularQuery.SingleQuery.SinglePartQuery.ReadingClauses); numReadingClauses != 1 { + return aspArguments, fmt.Errorf("expected one reading clause but saw %d", numReadingClauses) + } + + if err := RewriteParameters(regularQuery); err != nil { + return aspArguments, err + } + + readingClause := regularQuery.SingleQuery.SinglePartQuery.ReadingClauses[0] + + if readingClause.Match == nil || readingClause.Match.Where == nil { + return aspArguments, fmt.Errorf("no match or where clause specified") + } + + if len(readingClause.Match.Where.Expressions) != 1 { + return aspArguments, fmt.Errorf("expected where clause to have only one top-level and expression") + } + + if topLevelConjunction, typeOK := readingClause.Match.Where.Expressions[0].(*model.Conjunction); !typeOK { + 
return aspArguments, fmt.Errorf("expected where clause to have only one top-level and expression") + } else { + var ( + rootNodeCopy = model.Copy(topLevelConjunction) + edgeCopy = model.Copy(topLevelConjunction) + terminalNodeCopy = model.Copy(topLevelConjunction) + ) + + if err := model.Walk(rootNodeCopy, &Ripper{ + targetVariableSymbol: query.EdgeStartSymbol, + }); err != nil { + return aspArguments, err + } + + if err := model.Walk(edgeCopy, &Ripper{ + targetVariableSymbol: query.EdgeSymbol, + }); err != nil { + return aspArguments, err + } + + if err := model.Walk(terminalNodeCopy, &Ripper{ + targetVariableSymbol: query.EdgeEndSymbol, + }); err != nil { + return aspArguments, err + } + + buffer := &bytes.Buffer{} + emitter := pgsql.NewEmitter(false, kindMapper) + + if len(rootNodeCopy.Expressions) == 0 { + return aspArguments, fmt.Errorf("expected start node constraints but found none") + } else { + if err := analyzer.Analyze(rootNodeCopy, func(analyzerInst *analyzer.Analyzer) { + analyzer.WithVisitor(analyzerInst, RemoveEmptyExpressionLists) + }, pg.CollectPGSQLTypes); err != nil { + return aspArguments, err + } + + if err := emitter.WriteExpression(buffer, rootNodeCopy); err != nil { + return aspArguments, err + } else { + aspArguments.RootCriteria = buffer.String() + buffer.Reset() + } + } + + if len(edgeCopy.Expressions) > 0 { + if err := analyzer.Analyze(edgeCopy, func(analyzerInst *analyzer.Analyzer) { + analyzer.WithVisitor(analyzerInst, RemoveEmptyExpressionLists) + }, pg.CollectPGSQLTypes); err != nil { + return aspArguments, err + } + + if err := emitter.WriteExpression(buffer, edgeCopy); err != nil { + return aspArguments, err + } + + aspArguments.TraversalCriteria = buffer.String() + buffer.Reset() + } + + if len(terminalNodeCopy.Expressions) == 0 { + return aspArguments, fmt.Errorf("expected start node constraints but found none") + } else { + if err := analyzer.Analyze(terminalNodeCopy, func(analyzerInst *analyzer.Analyzer) { + 
analyzer.WithVisitor(analyzerInst, RemoveEmptyExpressionLists) + }, pg.CollectPGSQLTypes); err != nil { + return aspArguments, err + } + + if err := emitter.WriteExpression(buffer, terminalNodeCopy); err != nil { + return aspArguments, err + } else { + aspArguments.TerminalCriteria = buffer.String() + buffer.Reset() + } + } + } + + return aspArguments, nil +} diff --git a/packages/go/cypher/backend/pgsql/pgtransition/shortestpaths_test.go b/packages/go/cypher/backend/pgsql/pgtransition/shortestpaths_test.go new file mode 100644 index 0000000000..c26a41f160 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/pgtransition/shortestpaths_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pgtransition_test + +import ( + "github.com/specterops/bloodhound/cypher/backend/pgsql/pgtransition" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/dawgs/query" + "github.com/specterops/bloodhound/graphschema/ad" + "github.com/specterops/bloodhound/src/test" + "github.com/stretchr/testify/require" + "testing" +) + +type kindMapper struct{} + +func (k kindMapper) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) { + return make([]int16, len(kinds)), nil +} + +func TestTranslateAllShortestPaths(t *testing.T) { + builder := query.NewBuilder(&query.Cache{}) + builder.Apply(query.Where( + query.And( + query.And(query.Equals(query.StartID(), graph.ID(1)), query.Equals(query.EndProperty("name"), "1")), + query.KindIn(query.Relationship(), ad.PublishedTo, ad.IssuedSignedBy, ad.EnterpriseCAFor, ad.RootCAFor), + query.Equals(query.EndID(), graph.ID(5)), + ), + )) + + aspArguments, err := pgtransition.TranslateAllShortestPaths(builder.RegularQuery(), kindMapper{}) + test.RequireNilErr(t, err) + + require.Equal(t, "s.id = 1", aspArguments.RootCriteria, "Root Criteria") + require.Equal(t, "(r.kind_id = any(array[0]::int2[]) or r.kind_id = any(array[0]::int2[]) or r.kind_id = any(array[0]::int2[]) or r.kind_id = any(array[0]::int2[]))", aspArguments.TraversalCriteria, "Traversal Criteria") + require.Equal(t, "e.properties->'name' = '1' and e.id = 5", aspArguments.TerminalCriteria, "Terminal Criteria") +} diff --git a/packages/go/cypher/backend/pgsql/rewrite.go b/packages/go/cypher/backend/pgsql/rewrite.go new file mode 100644 index 0000000000..6470abd6f5 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/rewrite.go @@ -0,0 +1,86 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +import ( + "fmt" + "github.com/specterops/bloodhound/cypher/model" +) + +func rewrite(stack *model.WalkStack, original, rewritten model.Expression) error { + switch typedTrunk := stack.Trunk().(type) { + case model.ExpressionList: + for idx, expression := range typedTrunk.GetAll() { + if expression == original { + typedTrunk.Replace(idx, rewritten) + } + } + + case *model.FunctionInvocation: + for idx, expression := range typedTrunk.Arguments { + if expression == original { + typedTrunk.Arguments[idx] = rewritten + } + } + + case *model.ProjectionItem: + typedTrunk.Expression = rewritten + + case *model.SetItem: + if typedTrunk.Right == original { + typedTrunk.Right = rewritten + } else if typedTrunk.Left == original { + typedTrunk.Left = rewritten + } else { + return fmt.Errorf("unable to match original expression against SetItem left and right operands") + } + + case *model.PartialComparison: + typedTrunk.Right = rewritten + + case *model.RemoveItem: + switch typedRewritten := rewritten.(type) { + case *model.KindMatcher: + typedTrunk.KindMatcher = typedRewritten + } + + case *model.Projection: + for idx, projectionItem := range typedTrunk.Items { + if projectionItem == original { + typedTrunk.Items[idx] = rewritten + } + } + + case *model.Negation: + typedTrunk.Expression = rewritten + + case *model.Comparison: + if typedTrunk.Left == original { + typedTrunk.Left = rewritten + } + + case *model.Parenthetical: + if typedTrunk.Expression == original { + typedTrunk.Expression = 
rewritten + } + + default: + return fmt.Errorf("unable to replace expression for trunk type %T", stack.Trunk()) + } + + return nil +} diff --git a/packages/go/cypher/backend/pgsql/translation.go b/packages/go/cypher/backend/pgsql/translation.go new file mode 100644 index 0000000000..4623e446e2 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/translation.go @@ -0,0 +1,1101 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +import ( + "fmt" + "github.com/jackc/pgtype" + "github.com/specterops/bloodhound/cypher/analyzer" + "github.com/specterops/bloodhound/cypher/model" + "github.com/specterops/bloodhound/cypher/model/pg" + "strconv" + "strings" +) + +func GetSymbol(expression model.Expression) (string, error) { + switch typedExpression := expression.(type) { + case *model.PatternElement: + if nodePattern, isNodePattern := typedExpression.AsNodePattern(); isNodePattern { + if nodePattern.Binding != nil { + return GetSymbol(nodePattern.Binding) + } + } else if relationshipPattern, isRelationshipPattern := typedExpression.AsRelationshipPattern(); isRelationshipPattern { + if relationshipPattern.Binding != nil { + return GetSymbol(relationshipPattern.Binding) + } + } + + case *model.PatternPart: + if typedExpression.Binding != nil { + return GetSymbol(typedExpression.Binding) + } + + case *model.Variable: + return typedExpression.Symbol, nil + + case *pg.AnnotatedVariable: + 
return typedExpression.Symbol, nil + + default: + return "", fmt.Errorf("unable to source symbol from expression type %T", expression) + } + + return "", nil +} + +type Binder struct { + parameters map[string]*pg.AnnotatedParameter + bindingTypeMappings map[string]pg.DataType + aliases map[string]string + patternBindings map[string]struct{} + syntheticBindings map[string]struct{} + nextParameterID int + nextBindingID int +} + +func NewBinder() *Binder { + return &Binder{ + parameters: map[string]*pg.AnnotatedParameter{}, + bindingTypeMappings: map[string]pg.DataType{}, + aliases: map[string]string{}, + patternBindings: map[string]struct{}{}, + syntheticBindings: map[string]struct{}{}, + nextParameterID: 0, + nextBindingID: 0, + } +} + +func (s *Binder) Parameters() map[string]any { + parametersCopy := make(map[string]any, len(s.parameters)) + + for _, parameter := range s.parameters { + parametersCopy[parameter.Symbol] = parameter.Value + } + + return parametersCopy +} + +func (s *Binder) BindVariable(variable *model.Variable, bindingType pg.DataType) *pg.AnnotatedVariable { + s.bindingTypeMappings[variable.Symbol] = bindingType + return pg.NewAnnotatedVariable(variable, bindingType) +} + +func (s *Binder) BindPatternVariable(variable *model.Variable, bindingType pg.DataType) *pg.AnnotatedVariable { + s.patternBindings[variable.Symbol] = struct{}{} + return s.BindVariable(variable, bindingType) +} + +func (s *Binder) BindingType(binding string) (pg.DataType, bool) { + if bindingType, isBound := s.bindingTypeMappings[binding]; isBound { + return bindingType, isBound + } + + return pg.UnknownDataType, false +} + +func (s *Binder) LookupVariable(symbol string) (*pg.AnnotatedVariable, bool) { + if dataType, isBound := s.BindingType(symbol); isBound { + return pg.NewAnnotatedVariable(model.NewVariableWithSymbol(symbol), dataType), true + } + + return nil, false +} + +func (s *Binder) IsSynthetic(binding string) bool { + _, isSynthetic := s.syntheticBindings[binding] + 
return isSynthetic +} + +func (s *Binder) IsPatternBinding(binding string) bool { + _, isPatternBinding := s.patternBindings[binding] + return isPatternBinding +} + +func (s *Binder) IsBound(binding string) bool { + _, isBound := s.bindingTypeMappings[binding] + return isBound +} + +func (s *Binder) NewBinding(prefix string) string { + // Spin to win + for { + binding := prefix + strconv.Itoa(s.nextBindingID) + s.nextBindingID++ + + if !s.IsBound(binding) { + s.syntheticBindings[binding] = struct{}{} + return binding + } + } +} + +func (s *Binder) NewAnnotatedVariable(prefix string, bindingType pg.DataType) *pg.AnnotatedVariable { + return s.BindVariable(s.NewVariable(prefix), bindingType) +} + +func (s *Binder) NewVariable(prefix string) *model.Variable { + return model.NewVariableWithSymbol(s.NewBinding(prefix)) +} + +func (s *Binder) NewParameterSymbol() string { + nextParameterSymbol := "p" + strconv.Itoa(s.nextParameterID) + s.nextParameterID++ + + return nextParameterSymbol +} + +func (s *Binder) NewParameter(value any) (*pg.AnnotatedParameter, error) { + var ( + parameterSymbol = s.NewParameterSymbol() + ) + + if parameterTypeAnnotation, err := pg.NewSQLTypeAnnotationFromValue(value); err != nil { + return nil, err + } else { + parameter := pg.NewAnnotatedParameter(model.NewParameter(parameterSymbol, value), parameterTypeAnnotation.Type) + + // Record the parameter's value for mapping to the query later + s.parameters[parameterSymbol] = parameter + return parameter, nil + } +} + +func (s *Binder) NewLiteral(literal *model.Literal) (*pg.AnnotatedLiteral, error) { + if literalTypeAnnotation, err := pg.NewSQLTypeAnnotationFromValue(literal.Value); err != nil { + return nil, err + } else { + return pg.NewAnnotatedLiteral(literal, literalTypeAnnotation.Type), nil + } +} + +func (s *Binder) NewAlias(originalSymbol string, alias *model.Variable) *pg.AnnotatedVariable { + s.aliases[originalSymbol] = alias.Symbol + + if originalBindingType, isBound := 
s.bindingTypeMappings[originalSymbol]; isBound { + return s.BindVariable(alias, originalBindingType) + } + + return s.BindVariable(alias, pg.UnknownDataType) +} + +func (s *Binder) Scan(regularQuery *model.RegularQuery) error { + if err := analyzer.Analyze(regularQuery, func(analyzerInst *analyzer.Analyzer) { + analyzer.WithVisitor(analyzerInst, func(stack *model.WalkStack, node *model.Parameter) error { + // Rewrite all parameter symbols and collect their values + if annotatedParameter, err := s.NewParameter(node.Value); err != nil { + return err + } else { + return rewrite(stack, node, annotatedParameter) + } + }) + + analyzer.WithVisitor(analyzerInst, func(stack *model.WalkStack, node *model.Literal) error { + // Rewrite all parameter symbols and collect their values + if annotatedLiteral, err := s.NewLiteral(node); err != nil { + return err + } else { + return rewrite(stack, node, annotatedLiteral) + } + }) + + analyzer.WithVisitor(analyzerInst, func(stack *model.WalkStack, patternPart *model.PatternPart) error { + if patternPart.Binding != nil { + if bindingVariable, typeOK := patternPart.Binding.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable for pattern part binding but got: %T", patternPart.Binding) + } else { + patternPart.Binding = s.BindPatternVariable(bindingVariable, pg.Path) + } + } + return nil + }) + + analyzer.WithVisitor(analyzerInst, func(stack *model.WalkStack, patternElement *model.PatternElement) error { + // Eagerly bind all ReadingClause pattern elements to simplify referencing when crafting SQL join statements + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + if nodePattern.Binding == nil { + nodePattern.Binding = s.NewAnnotatedVariable("n", pg.Node) + } else if bindingVariable, typeOK := nodePattern.Binding.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable for node pattern binding but got: %T", nodePattern.Binding) + } else if _, isPatternPredicate := 
stack.Trunk().(*model.PatternPredicate); isPatternPredicate { + nodePattern.Binding = s.BindVariable(bindingVariable, pg.Node) + } else { + nodePattern.Binding = s.BindPatternVariable(bindingVariable, pg.Node) + } + } else { + relationshipPattern, _ := patternElement.AsRelationshipPattern() + + if relationshipPattern.Binding == nil { + relationshipPattern.Binding = s.NewAnnotatedVariable("e", pg.Edge) + } else if bindingVariable, typeOK := relationshipPattern.Binding.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable for relationship pattern binding but got: %T", relationshipPattern.Binding) + } else if _, isPatternPredicate := stack.Trunk().(*model.PatternPredicate); isPatternPredicate { + relationshipPattern.Binding = s.BindVariable(bindingVariable, pg.Edge) + } else { + relationshipPattern.Binding = s.BindPatternVariable(bindingVariable, pg.Edge) + } + } + + return nil + }) + + analyzer.WithVisitor(analyzerInst, func(_ *model.WalkStack, node *model.ProjectionItem) error { + if bindingVariable, isVariable := node.Binding.(*model.Variable); node.Binding != nil && isVariable { + if projectionVariable, isVariable := node.Expression.(*model.Variable); isVariable { + node.Binding = s.NewAlias(projectionVariable.Symbol, bindingVariable) + } + } + + return nil + }) + + analyzer.WithVisitor(analyzerInst, func(_ *model.WalkStack, node *model.Delete) error { + for idx, expression := range node.Expressions { + switch typedExpression := expression.(type) { + case *model.Variable: + if annotatedVariable, isAnnotated := s.LookupVariable(typedExpression.Symbol); !isAnnotated { + return fmt.Errorf("unable to look up type annotation for variable reference: %s", typedExpression.Symbol) + } else { + node.Expressions[idx] = annotatedVariable + } + } + } + + return nil + }) + }, pg.CollectPGSQLTypes); err != nil { + return err + } + + return nil +} + +type Translator struct { + builder *strings.Builder + Bindings *Binder + kindMapper KindMapper + regularQuery 
*model.RegularQuery +} + +func NewTranslator(kindMapper KindMapper, bindings *Binder, regularQuery *model.RegularQuery) *Translator { + return &Translator{ + builder: &strings.Builder{}, + kindMapper: kindMapper, + Bindings: bindings, + regularQuery: regularQuery, + } +} + +func (s *Translator) rewriteUpdatingClauses(_ *model.WalkStack, singlePartQuery *model.SinglePartQuery) error { + return NewUpdateClauseRewriter(s.Bindings, s.kindMapper).RewriteUpdatingClauses(singlePartQuery) +} + +func (s *Translator) liftNodePatternCriteria(_ *model.WalkStack, nodePattern *model.NodePattern) ([]model.Expression, error) { + var criteria []model.Expression + + if nodePattern.Binding == nil { + nodePattern.Binding = s.Bindings.NewVariable("n") + } + + if len(nodePattern.Kinds) > 0 { + kindMatcher := model.NewKindMatcher(nodePattern.Binding, nodePattern.Kinds) + criteria = append(criteria, pg.NewAnnotatedKindMatcher(kindMatcher, pg.Node)) + } + + if nodePattern.Properties != nil { + nodePropertyMatchers := nodePattern.Properties.(*model.Properties) + + if nodePropertyMatchers.Parameter != nil { + return nil, fmt.Errorf("unable to translate property matcher paramter for node %s", nodePattern.Binding) + } + + for propertyName, matcherValue := range nodePropertyMatchers.Map { + if bindingVariable, typeOK := nodePattern.Binding.(*pg.AnnotatedVariable); !typeOK { + return nil, fmt.Errorf("unexpected node pattern binding type for node pattern: %T", nodePattern.Binding) + } else { + propertyLookup := model.NewPropertyLookup(bindingVariable.Symbol, propertyName) + + if annotation, err := pg.NewSQLTypeAnnotationFromExpression(matcherValue); err != nil { + return nil, err + } else { + criteria = append(criteria, model.NewComparison( + pg.NewAnnotatedPropertyLookup(propertyLookup, annotation.Type), + model.OperatorEquals, + matcherValue, + )) + } + } + } + } + + return criteria, nil +} + +func (s *Translator) liftRelationshipPatternCriteria(_ *model.WalkStack, relationshipPattern 
*model.RelationshipPattern) ([]model.Expression, error) { + var criteria []model.Expression + + if relationshipPattern.Binding == nil { + relationshipPattern.Binding = s.Bindings.NewVariable("e") + } + + if len(relationshipPattern.Kinds) > 0 { + kindMatcher := model.NewKindMatcher(relationshipPattern.Binding, relationshipPattern.Kinds) + criteria = append(criteria, pg.NewAnnotatedKindMatcher(kindMatcher, pg.Edge)) + } + + if relationshipPattern.Properties != nil { + edgePropertyMatchers := relationshipPattern.Properties.(*model.Properties) + + if edgePropertyMatchers.Parameter != nil { + return nil, fmt.Errorf("unable to translate property matcher paramter for edge %s", relationshipPattern.Binding) + } + + for propertyName, matcherValue := range edgePropertyMatchers.Map { + if bindingVariable, typeOK := relationshipPattern.Binding.(*pg.AnnotatedVariable); !typeOK { + return nil, fmt.Errorf("unexpected relationship pattern binding type: %T", relationshipPattern.Binding) + } else { + propertyLookup := model.NewPropertyLookup(bindingVariable.Symbol, propertyName) + + if annotation, err := pg.NewSQLTypeAnnotationFromExpression(matcherValue); err != nil { + return nil, err + } else { + criteria = append(criteria, model.NewComparison( + pg.NewAnnotatedPropertyLookup(propertyLookup, annotation.Type), + model.OperatorEquals, + matcherValue, + )) + } + } + } + } + + return criteria, nil +} + +func (s *Translator) liftPatternElementCriteria(stack *model.WalkStack, patternElement *model.PatternElement) ([]model.Expression, error) { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + return s.liftNodePatternCriteria(stack, nodePattern) + } + + relationshipPattern, _ := patternElement.AsRelationshipPattern() + return s.liftRelationshipPatternCriteria(stack, relationshipPattern) +} + +func (s *Translator) translatePatternPredicates(stack *model.WalkStack, patternPredicate *model.PatternPredicate) error { + var ( + subqueryFilters 
[]model.Expression + subquery = &pg.Subquery{ + PatternElements: patternPredicate.PatternElements, + } + ) + + for _, patternElement := range subquery.PatternElements { + if nodePattern, isNodePattern := patternElement.AsNodePattern(); isNodePattern { + // Is the node pattern bound to a variable and was that variable bound earlier in the AST? + if bindingVariable, typeOK := nodePattern.Binding.(*pg.AnnotatedVariable); !typeOK { + return fmt.Errorf("unexpected node pattern binding type for pattern predicate: %T", nodePattern.Binding) + } else if nodePattern.Binding != nil && !s.Bindings.IsSynthetic(bindingVariable.Symbol) && s.Bindings.IsPatternBinding(bindingVariable.Symbol) { + // Since this pattern element is bound to a pre-existing referenced pattern element we have to match + // against it by its identity + var ( + oldBinding = nodePattern.Binding + newBinding = s.Bindings.NewAnnotatedVariable("n", bindingVariable.Type) + ) + + nodePattern.Binding = newBinding + subqueryFilters = append(subqueryFilters, model.NewComparison( + model.NewSimpleFunctionInvocation( + cypherIdentityFunction, + oldBinding, + ), + model.OperatorEquals, + model.NewSimpleFunctionInvocation( + cypherIdentityFunction, + newBinding, + ), + )) + } + + if criteria, err := s.liftNodePatternCriteria(stack, nodePattern); err != nil { + return err + } else { + subqueryFilters = append(subqueryFilters, criteria...) + } + } else { + relationshipPattern, _ := patternElement.AsRelationshipPattern() + + // Is the relationship pattern bound to a variable and was that variable bound earlier in the AST? 
+ if bindingVariable, typeOK := relationshipPattern.Binding.(*pg.AnnotatedVariable); !typeOK { + return fmt.Errorf("unexpected relationship pattern binding type: %T", relationshipPattern.Binding) + } else if relationshipPattern.Binding != nil && !s.Bindings.IsSynthetic(bindingVariable.Symbol) && s.Bindings.IsPatternBinding(bindingVariable.Symbol) { + // Since this pattern element is bound to a pre-existing referenced pattern element we have to match + // against it by its identity + var ( + oldBinding = relationshipPattern.Binding + newBinding = s.Bindings.NewAnnotatedVariable("e", bindingVariable.Type) + ) + + relationshipPattern.Binding = newBinding + subqueryFilters = append(subqueryFilters, model.NewComparison( + model.NewSimpleFunctionInvocation( + cypherIdentityFunction, + oldBinding, + ), + model.OperatorEquals, + model.NewSimpleFunctionInvocation( + cypherIdentityFunction, + newBinding, + ), + )) + } + + if criteria, err := s.liftRelationshipPatternCriteria(stack, relationshipPattern); err != nil { + return err + } else { + subqueryFilters = append(subqueryFilters, criteria...) + } + } + + } + + if len(subqueryFilters) > 0 { + subquery.Filter = model.NewConjunction(subqueryFilters...) + + return rewrite(stack, patternPredicate, subquery) + } + + return nil +} + +func (s *Translator) liftMatchCriteria(stack *model.WalkStack, match *model.Match) error { + var additionalCriteria []model.Expression + + for _, patternPart := range match.Pattern { + for _, patternElement := range patternPart.PatternElements { + if patternElementCriteria, err := s.liftPatternElementCriteria(stack, patternElement); err != nil { + return err + } else { + additionalCriteria = append(additionalCriteria, patternElementCriteria...) 
+ } + } + } + + if len(additionalCriteria) > 0 { + if match.Where == nil { + match.Where = model.NewWhere() + } + + match.Where.Expressions = []model.Expression{ + model.NewConjunction(append(additionalCriteria, match.Where.Expressions...)...), + } + } + + return nil +} + +func (s *Translator) annotateKindMatchers(stack *model.WalkStack, kindMatcher *model.KindMatcher) error { + switch typedExpression := kindMatcher.Reference.(type) { + case *pg.AnnotatedVariable: + return rewrite(stack, kindMatcher, pg.NewAnnotatedKindMatcher(kindMatcher, typedExpression.Type)) + + case *model.Variable: + if dataType, hasBindingType := s.Bindings.BindingType(typedExpression.Symbol); !hasBindingType { + return fmt.Errorf("unable to locate a binding type for variable %s", typedExpression.Symbol) + } else { + return rewrite(stack, kindMatcher, pg.NewAnnotatedKindMatcher(kindMatcher, dataType)) + } + + default: + return fmt.Errorf("unexpected kind matcher reference type %T", kindMatcher.Reference) + } +} + +func (s *Translator) rewriteComparison(stack *model.WalkStack, comparison *model.Comparison) (bool, error) { + // Is this a property lookup comparison? 
+ switch typedLeftOperand := comparison.Left.(type) { + case *model.PropertyLookup: + // Try to suss out if this is a property existence check + if len(comparison.Partials) == 1 { + comparisonPartial := comparison.Partials[0] + + switch typedRightHand := comparisonPartial.Right.(type) { + case *pg.AnnotatedLiteral: + if typedRightHand.Null { + // This is a null check for a property and must be rewritten for SQL + switch comparisonPartial.Operator { + case model.OperatorIsNot: + if leftOperandVariable, isVariable := typedLeftOperand.Atom.(*model.Variable); !isVariable { + return false, fmt.Errorf("unexpected expression as left operand %T", typedLeftOperand.Atom) + } else if leftOperandTypedVariable, isBound := s.Bindings.LookupVariable(leftOperandVariable.Symbol); !isBound { + return false, fmt.Errorf("left operand varaible %s is not bound", leftOperandTypedVariable.Symbol) + } else if err := rewrite(stack, comparison, model.NewComparison( + &pg.PropertiesReference{ + // TODO: Might need a copy? 
+ Reference: leftOperandTypedVariable, + }, + OperatorJSONBFieldExists, + pg.NewStringLiteral(typedLeftOperand.Symbols[0]), + )); err != nil { + return false, err + } + + case model.OperatorIs: + if leftOperandVariable, isVariable := typedLeftOperand.Atom.(*model.Variable); !isVariable { + return false, fmt.Errorf("unexpected expression as left operand %T", typedLeftOperand.Atom) + } else if leftOperandTypedVariable, isBound := s.Bindings.LookupVariable(leftOperandVariable.Symbol); !isBound { + return false, fmt.Errorf("left operand varaible %s is not bound", leftOperandTypedVariable.Symbol) + } else if err := rewrite(stack, comparison, model.NewNegation( + model.NewComparison( + &pg.PropertiesReference{ + Reference: leftOperandTypedVariable, + }, + OperatorJSONBFieldExists, + pg.NewStringLiteral(typedLeftOperand.Symbols[0]), + )), + ); err != nil { + return false, err + } + } + + return true, nil + } + } + } + } + + return false, nil +} + +func (s *Translator) rewritePartialComparison(_ *model.WalkStack, partial *model.PartialComparison) error { + switch partial.Operator { + case model.OperatorIn: + switch partial.Right.(type) { + case *model.Parameter, *pg.AnnotatedParameter: + // When the "in" operator addresses right-hand parameter it must be rewritten as: "= any($param)" + partial.Operator = model.OperatorEquals + partial.Right = model.NewSimpleFunctionInvocation(pgsqlAnyFunction, partial.Right) + } + + case model.OperatorStartsWith: + // Replace this operator with the like operator + partial.Operator = OperatorLike + + // If the right side isn't a string for any of these it's an error + switch typedRightOperand := partial.Right.(type) { + case *pg.AnnotatedLiteral: + if stringValue, isString := typedRightOperand.Value.(string); !isString { + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } else { + // Strip the wrapping single quotes first + s.builder.Reset() + 
s.builder.WriteString("'") + s.builder.WriteString(stringValue[1 : len(stringValue)-1]) + s.builder.WriteString("%'") + + typedRightOperand.Value = s.builder.String() + } + + case *pg.AnnotatedParameter: + if stringValue, isString := typedRightOperand.Value.(string); !isString { + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } else { + // Parameters are raw values and have no quotes + s.builder.Reset() + s.builder.WriteString(stringValue) + s.builder.WriteString("%") + + typedRightOperand.Value = s.builder.String() + } + + default: + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } + + case model.OperatorContains: + // Replace this operator with the like operator + partial.Operator = OperatorLike + + // If the right side isn't a string for any of these it's an error + switch typedRightOperand := partial.Right.(type) { + case *pg.AnnotatedLiteral: + if stringValue, isString := typedRightOperand.Value.(string); !isString { + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } else { + // Strip the wrapping single quotes first + s.builder.Reset() + s.builder.WriteString("'%") + s.builder.WriteString(stringValue[1 : len(stringValue)-1]) + s.builder.WriteString("%'") + + typedRightOperand.Value = s.builder.String() + } + + case *pg.AnnotatedParameter: + if stringValue, isString := typedRightOperand.Value.(string); !isString { + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } else { + // Parameters are raw values and have no quotes + s.builder.Reset() + s.builder.WriteString("%") + s.builder.WriteString(stringValue) + s.builder.WriteString("%") + + typedRightOperand.Value = s.builder.String() + } + + default: + return 
fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } + + case model.OperatorEndsWith: + // Replace this operator with the like operator + partial.Operator = OperatorLike + + // If the right side isn't a string for any of these it's an error + switch typedRightOperand := partial.Right.(type) { + case *pg.AnnotatedLiteral: + if stringValue, isString := typedRightOperand.Value.(string); !isString { + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } else { + // Strip the wrapping single quotes first + s.builder.Reset() + s.builder.WriteString("'%") + s.builder.WriteString(stringValue[1 : len(stringValue)-1]) + s.builder.WriteString("'") + + typedRightOperand.Value = s.builder.String() + } + + case *pg.AnnotatedParameter: + if stringValue, isString := typedRightOperand.Value.(string); !isString { + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } else { + // Parameters are raw values and have no quotes + s.builder.Reset() + s.builder.WriteString("%") + s.builder.WriteString(stringValue) + + typedRightOperand.Value = s.builder.String() + } + + default: + return fmt.Errorf("string operator \"%s\" expects a string literal or parameter as its right opperand", partial.Operator.String()) + } + + case model.OperatorEquals: + switch typedRightOperand := partial.Right.(type) { + case *pg.AnnotatedLiteral: + // If this is an array type then first wrap it in the `to_jsonb` function + if typedRightOperand.Type.IsArrayType() { + partial.Right = model.NewSimpleFunctionInvocation(pgsqlToJSONBFunction, partial.Right) + } + + case *pg.AnnotatedParameter: + // If this is an array type then rewrite it as a JSONB value + if typedRightOperand.Type.IsArrayType() { + newParameter := &pgtype.JSONB{} + + if err := 
// annotateComparisons walks a comparison expression tree and infers a single
// SQL type annotation for it from any annotated parameters, annotated
// literals, or well-known temporal cypher functions it contains. The inferred
// type is then used to annotate property lookups on both sides of the
// comparison so the SQL emitter can cast JSONB fields appropriately.
//
// Mixed types inside one comparison (e.g. a date literal compared against a
// numeric parameter) are rejected with an error. Each partial comparison is
// also passed through rewritePartialComparison as it is visited.
func (s *Translator) annotateComparisons(stack *model.WalkStack, comparison *model.Comparison) error {
	var (
		typeAnnotation *pg.SQLTypeAnnotation
		operator       model.Operator
	)

	// Null checks are rewritten wholesale by rewriteComparison; when that
	// happens there is nothing left to annotate here
	if rewritten, err := s.rewriteComparison(stack, comparison); err != nil {
		return err
	} else if rewritten {
		return nil
	}

	// Manual DFS over the comparison's sub-expressions using an explicit stack
	for comparisonWalkStack := []model.Expression{comparison}; len(comparisonWalkStack) > 0; {
		next := comparisonWalkStack[len(comparisonWalkStack)-1]
		comparisonWalkStack = comparisonWalkStack[:len(comparisonWalkStack)-1]

		switch typedNode := next.(type) {
		case *model.Comparison:
			comparisonWalkStack = append(comparisonWalkStack, typedNode.Left)

			for _, partial := range typedNode.Partials {
				comparisonWalkStack = append(comparisonWalkStack, partial)
			}

		case *model.PartialComparison:
			// TODO: Overloading the operator means that we may miss partial comparison continuations
			operator = typedNode.Operator

			if err := s.rewritePartialComparison(stack, typedNode); err != nil {
				return err
			}

			comparisonWalkStack = append(comparisonWalkStack, typedNode.Right)

		case *pg.AnnotatedParameter:
			// Adopt the parameter's type, or fail if it conflicts with a previously seen type
			if typeAnnotation == nil {
				typeAnnotation = &pg.SQLTypeAnnotation{
					Type: typedNode.Type,
				}
			} else if typeAnnotation.Type != typedNode.Type {
				return fmt.Errorf("comparison contains mixed types: %s and %s", typeAnnotation.Type, typedNode.Type)
			}

		case *pg.AnnotatedLiteral:
			// Adopt the literal's type, or fail if it conflicts with a previously seen type
			if typeAnnotation == nil {
				typeAnnotation = &pg.SQLTypeAnnotation{
					Type: typedNode.Type,
				}
			} else if typeAnnotation.Type != typedNode.Type {
				return fmt.Errorf("comparison contains mixed types: %s and %s", typeAnnotation.Type, typedNode.Type)
			}

		case *model.FunctionInvocation:
			var functionInvocationTypeAnnotation *pg.SQLTypeAnnotation

			// Temporal cypher functions imply the SQL type of the comparison
			switch typedNode.Name {
			case cypherDateFunction:
				functionInvocationTypeAnnotation = &pg.SQLTypeAnnotation{
					Type: pg.Date,
				}

			case cypherTimeFunction:
				functionInvocationTypeAnnotation = &pg.SQLTypeAnnotation{
					Type: pg.TimeWithTimeZone,
				}

			case cypherLocalTimeFunction:
				functionInvocationTypeAnnotation = &pg.SQLTypeAnnotation{
					Type: pg.TimeWithoutTimeZone,
				}

			case cypherDateTimeFunction:
				functionInvocationTypeAnnotation = &pg.SQLTypeAnnotation{
					Type: pg.TimestampWithTimeZone,
				}

			case cypherLocalDateTimeFunction:
				functionInvocationTypeAnnotation = &pg.SQLTypeAnnotation{
					Type: pg.TimestampWithoutTimeZone,
				}

			case cypherDurationFunction:
				functionInvocationTypeAnnotation = &pg.SQLTypeAnnotation{
					Type: pg.Interval,
				}

			default:
				// If we couldn't figure out a type from the function name then inspect the function's argument list
				comparisonWalkStack = append(comparisonWalkStack, typedNode.Arguments...)
			}

			// If there was a function invocation type, check to validate that we're not producing mixed type
			// annotations for the comparison
			if functionInvocationTypeAnnotation != nil {
				if typeAnnotation == nil {
					typeAnnotation = functionInvocationTypeAnnotation
				} else if typeAnnotation.Type != functionInvocationTypeAnnotation.Type {
					return fmt.Errorf("comparison contains mixed types: %s and %s", typeAnnotation.Type, functionInvocationTypeAnnotation.Type)
				}
			}
		}
	}

	if typeAnnotation != nil {
		if leftHandPropertyLookup, typeOK := comparison.Left.(*model.PropertyLookup); typeOK {
			leftOperandType := typeAnnotation.Type

			// If this is an array type we need to do some special rewriting negotiation for different operators
			if typeAnnotation.Type.IsArrayType() {
				switch operator {
				case model.OperatorIn:
					// If this is an operation such that in then we must wrap the right hand
					// operand using the pgsql any() function and type the left hand operand to the array's base type
					if baseType, err := typeAnnotation.Type.ArrayBaseType(); err != nil {
						return err
					} else {
						leftOperandType = baseType
					}

				default:
					// If this isn't a contains operator then rewrite the left hand operand type to jsonb and wrap
					// the right hand operand in the to_json function with the type annotation of jsonb
					leftOperandType = pg.JSONB
				}
			}

			// Rewrite the left operand so that the property lookup is correctly type annotated
			comparison.Left = pg.NewAnnotatedPropertyLookup(leftHandPropertyLookup, leftOperandType)

			for _, partialComparison := range comparison.Partials {
				switch typedPartialComparison := partialComparison.Right.(type) {
				case *model.PropertyLookup:
					// Make sure right hand operand property lookups are correctly type annotated
					annotatedPropertyLookup := pg.NewAnnotatedPropertyLookup(typedPartialComparison, typeAnnotation.Type)

					if err := rewrite(stack, partialComparison.Right, annotatedPropertyLookup); err != nil {
						return err
					}
				}
			}
		}
	}

	return nil
}
// rewriteNegations wraps a negated expression list in a parenthetical so the
// emitted SQL evaluates with the grouping the query author intended.
func (s *Translator) rewriteNegations(_ *model.WalkStack, negation *model.Negation) error {
	// Wrap negations that contain a list of expressions in a parenthetical expression to ensure that evaluation
	// happens as intended by the author of the query
	if _, isExpressionList := negation.Expression.(model.ExpressionList); isExpressionList {
		negation.Expression = model.NewParenthetical(negation.Expression)
	}

	return nil
}

// rewriteStringNegations rewrites negated string comparisons (not ... starts
// with / ends with / contains) as "(not <comparison> or <left> is null)".
// In cypher a negated string comparison against a missing (null) property is
// expected to match, while SQL "not ... like ..." yields null; the appended
// null check restores the cypher semantics. The rewritten expression replaces
// the negation in the parent expression list.
func (s *Translator) rewriteStringNegations(stack *model.WalkStack, negation *model.Negation) error {
	var rewritten any

	// If this is a negation then we should check to see if it's a comparison
	switch comparison := negation.Expression.(type) {
	case *model.Comparison:
		firstPartial := comparison.FirstPartial()

		// If the negated expression is a comparison check to see if it's a string comparison. This is done since
		// comparison semantics for strings regarding `null` has edge cases that must be accounted for
		switch firstPartial.Operator {
		case model.OperatorStartsWith, model.OperatorEndsWith, model.OperatorContains:
			// Rewrite this comparison as a disjunction of the negation and a follow-on comparison to handle null
			// checks
			rewritten = &model.Parenthetical{
				Expression: model.NewDisjunction(
					negation,
					model.NewComparison(comparison.Left, model.OperatorIs, pg.NewAnnotatedLiteral(model.NewLiteral(nil, true), pg.Null)),
				),
			}
		}
	}

	// If we rewrote this element, replace it
	if rewritten != nil {
		switch typedParent := stack.Trunk().(type) {
		case model.ExpressionList:
			// Find the negation in the parent by identity and swap in the rewritten form
			for idx, expression := range typedParent.GetAll() {
				if expression == negation {
					typedParent.Replace(idx, rewritten)
					break
				}
			}

		default:
			// Replacement is only supported inside expression lists
			return fmt.Errorf("unable to replace rewritten string negation operation for parent type %T", stack.Trunk())
		}
	}

	return nil
}

// rewriteFunctionInvocations replaces cypher introspection functions with
// direct references into the PG graph schema: labels(n) becomes a node kinds
// reference, type(r) becomes an edge kind reference, and a property lookup
// passed to toLower() is annotated as text so it is cast correctly.
func (s *Translator) rewriteFunctionInvocations(stack *model.WalkStack, functionInvocation *model.FunctionInvocation) error {
	switch functionInvocation.Name {
	case cypherNodeLabelsFunction:
		switch typedArgument := functionInvocation.Arguments[0].(type) {
		case *model.Variable:
			return rewrite(stack, functionInvocation, pg.NewNodeKindsReference(pg.NewAnnotatedVariable(typedArgument, pg.Node)))

		case *pg.AnnotatedVariable:
			return rewrite(stack, functionInvocation, pg.NewNodeKindsReference(typedArgument))

		default:
			return fmt.Errorf("expected a variable as the first argument in %s function", functionInvocation.Name)
		}

	case cypherEdgeTypeFunction:
		switch typedArgument := functionInvocation.Arguments[0].(type) {
		case *model.Variable:
			return rewrite(stack, functionInvocation, pg.NewEdgeKindReference(pg.NewAnnotatedVariable(typedArgument, pg.Edge)))

		case *pg.AnnotatedVariable:
			return rewrite(stack, functionInvocation, pg.NewEdgeKindReference(typedArgument))

		default:
			return fmt.Errorf("expected a variable as the first argument in %s function", functionInvocation.Name)
		}

	case cypherToLowerFunction:
		switch typedArgument := functionInvocation.Arguments[0].(type) {
		case *model.PropertyLookup:
			// Annotate the lookup as text so toLower receives a string, not raw JSONB
			functionInvocation.Arguments[0] = pg.NewAnnotatedPropertyLookup(typedArgument, pg.Text)
		}
	}

	return nil
}
pg.NewEdgeKindReference(typedArgument)) + + default: + return fmt.Errorf("expected a variable as the first argument in %s function", functionInvocation.Name) + } + + case cypherToLowerFunction: + switch typedArgument := functionInvocation.Arguments[0].(type) { + case *model.PropertyLookup: + functionInvocation.Arguments[0] = pg.NewAnnotatedPropertyLookup(typedArgument, pg.Text) + } + } + + return nil +} + +func (s *Translator) annotateProjectionItems(_ *model.WalkStack, projectionItem *model.ProjectionItem) error { + switch typedExpression := projectionItem.Expression.(type) { + case *model.Variable: + if bindingType, isBound := s.Bindings.BindingType(typedExpression.Symbol); !isBound { + return fmt.Errorf("variable %s for projection item is not bound", typedExpression.Symbol) + } else { + projectionItem.Expression = pg.NewEntity(pg.NewAnnotatedVariable(typedExpression, bindingType)) + + // Set projection item binding to the variable reference if there's no binding present + if projectionItem.Binding == nil { + projectionItem.Binding = pg.NewAnnotatedVariable(typedExpression, bindingType) + } + } + } + + return nil +} + +func (s *Translator) validatePropertyLookups(_ *model.WalkStack, propertyLookup *model.PropertyLookup) error { + if len(propertyLookup.Symbols) != 1 { + return fmt.Errorf("expected a single-depth propertly lookup") + } + + return nil +} +func (s *Translator) removeEmptyExpressionLists(stack *model.WalkStack, element model.Expression) error { + var ( + shouldRemove = false + shouldReplace = false + + replacementExpression model.Expression + ) + + switch typedElement := element.(type) { + case model.ExpressionList: + shouldRemove = typedElement.Len() == 0 + + case *model.Parenthetical: + switch typedParentheticalElement := typedElement.Expression.(type) { + case model.ExpressionList: + numExpressions := typedParentheticalElement.Len() + + shouldRemove = numExpressions == 0 + shouldReplace = numExpressions == 1 + + if shouldReplace { + // Dump the 
parenthetical and the joined expression by grabbing the only element in the joined + // expression for replacement + replacementExpression = typedParentheticalElement.Get(0) + } + } + } + + if shouldRemove { + switch typedParent := stack.Trunk().(type) { + case model.ExpressionList: + typedParent.Remove(element) + } + } else if shouldReplace { + switch typedParent := stack.Trunk().(type) { + case model.ExpressionList: + typedParent.Replace(typedParent.IndexOf(element), replacementExpression) + } + } + + return nil +} + +func (s *Translator) rewriteKindFilters(stack *model.WalkStack, disjunction *model.Disjunction) error { + var ( + kindsByRef = map[string]*pg.AnnotatedKindMatcher{} + nonKindMatcherExpressions []model.Expression + ) + + for _, expression := range disjunction.GetAll() { + switch typedExpression := expression.(type) { + case *pg.AnnotatedKindMatcher: + if binding, err := GetSymbol(typedExpression.Reference); err != nil { + return err + } else if kindMatcher, hasMatcher := kindsByRef[binding]; hasMatcher { + kindMatcher.Kinds = append(kindMatcher.Kinds, typedExpression.Kinds...) 
+ } else { + kindsByRef[binding] = pg.Copy(typedExpression) + } + + default: + nonKindMatcherExpressions = append(nonKindMatcherExpressions, typedExpression) + } + } + + kindMatchers := make([]model.Expression, 0, len(kindsByRef)) + + for _, kindMatcher := range kindsByRef { + kindMatchers = append(kindMatchers, kindMatcher) + } + + if len(nonKindMatcherExpressions) == 0 { + if len(kindMatchers) == 1 { + return rewrite(stack, disjunction, kindMatchers[0]) + } else { + return rewrite(stack, disjunction, model.NewDisjunction(kindMatchers...)) + } + } else if len(kindMatchers) > 0 { + return rewrite(stack, disjunction, model.NewDisjunction(append(nonKindMatcherExpressions, kindMatchers...)...)) + } + + return nil +} + +func Translate(regularQuery *model.RegularQuery, kindMapper KindMapper) (map[string]any, error) { + var ( + bindings = NewBinder() + rewriter = NewTranslator(kindMapper, bindings, regularQuery) + ) + + if err := bindings.Scan(regularQuery); err != nil { + return nil, err + } + + // Rewrite phase + if err := analyzer.Analyze(regularQuery, func(analyzerInst *analyzer.Analyzer) { + analyzer.WithVisitor(analyzerInst, rewriter.rewriteStringNegations) + analyzer.WithVisitor(analyzerInst, rewriter.annotateProjectionItems) + analyzer.WithVisitor(analyzerInst, rewriter.validatePropertyLookups) + analyzer.WithVisitor(analyzerInst, rewriter.annotateKindMatchers) + analyzer.WithVisitor(analyzerInst, rewriter.liftMatchCriteria) + analyzer.WithVisitor(analyzerInst, rewriter.annotateComparisons) + analyzer.WithVisitor(analyzerInst, rewriter.translatePatternPredicates) + analyzer.WithVisitor(analyzerInst, rewriter.rewriteFunctionInvocations) + analyzer.WithVisitor(analyzerInst, rewriter.rewriteUpdatingClauses) + }, pg.CollectPGSQLTypes); err != nil { + return nil, err + } + + // Optimization phase + if err := analyzer.Analyze(regularQuery, func(analyzerInst *analyzer.Analyzer) { + analyzer.WithVisitor(analyzerInst, rewriter.rewriteNegations) + 
analyzer.WithVisitor(analyzerInst, rewriter.rewriteKindFilters) + analyzer.WithVisitor(analyzerInst, rewriter.removeEmptyExpressionLists) + }, pg.CollectPGSQLTypes); err != nil { + return nil, err + } + + return bindings.Parameters(), nil +} diff --git a/packages/go/cypher/backend/pgsql/type.go b/packages/go/cypher/backend/pgsql/type.go new file mode 100644 index 0000000000..f43f5b0410 --- /dev/null +++ b/packages/go/cypher/backend/pgsql/type.go @@ -0,0 +1,95 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +import ( + "bytes" + "encoding/json" + "github.com/jackc/pgtype" + "github.com/specterops/bloodhound/dawgs/graph" +) + +func ValueToJSONB(value any) (pgtype.JSONB, error) { + var jsonbArgument pgtype.JSONB + + return jsonbArgument, jsonbArgument.Set(value) +} + +func Int32SliceToInt4Array(value []int32) (pgtype.Int4Array, error) { + var pgInt4Array pgtype.Int4Array + + return pgInt4Array, pgInt4Array.Set(value) +} + +func IDSliceToInt8Array(value []graph.ID) (pgtype.Int8Array, error) { + var pgInt8Array pgtype.Int8Array + + return pgInt8Array, pgInt8Array.Set(value) +} + +func StringSliceToTextArray(values []string) (pgtype.TextArray, error) { + var pgTextArray pgtype.TextArray + return pgTextArray, pgTextArray.Set(values) +} + +func MapStringAnyToJSONB(value map[string]any) (pgtype.JSONB, error) { + var jsonb pgtype.JSONB + + return jsonb, jsonb.Set(value) +} + +func PropertiesToJSONB(properties *graph.Properties) (pgtype.JSONB, error) { + return MapStringAnyToJSONB(properties.MapOrEmpty()) +} + +func JSONBToProperties(jsonb pgtype.JSONB) (*graph.Properties, error) { + propertiesMap := make(map[string]any) + + if err := jsonb.AssignTo(&propertiesMap); err != nil { + return nil, err + } + + return graph.AsProperties(propertiesMap), nil +} + +func MatcherAsJSONB(fieldName string, value any) (pgtype.JSONB, error) { + var ( + matcher = bytes.Buffer{} + jsonbMatcher = pgtype.JSONB{} + ) + + // Prepare the JSONB matcher + if marshalledValue, err := json.Marshal(value); err != nil { + return jsonbMatcher, err + } else { + matcher.WriteString(`{"`) + matcher.WriteString(fieldName) + matcher.WriteString(`":`) + matcher.Write(marshalledValue) + matcher.WriteString(`}`) + } + + return ValueToJSONB(matcher.Bytes()) +} + +func MustMatcherAsJSONB(fieldName string, value any) pgtype.JSONB { + if jsonbMatcher, err := MatcherAsJSONB(fieldName, value); err != nil { + panic(err) + } else { + return jsonbMatcher + } +} 
diff --git a/packages/go/cypher/backend/pgsql/util.go b/packages/go/cypher/backend/pgsql/util.go new file mode 100644 index 0000000000..bcb09a601a --- /dev/null +++ b/packages/go/cypher/backend/pgsql/util.go @@ -0,0 +1,65 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pgsql + +import ( + "io" + "strconv" + "strings" +) + +func JoinUint[T uint | uint8 | uint16 | uint32 | uint64](values []T, separator string) string { + builder := strings.Builder{} + + for idx := 0; idx < len(values); idx++ { + if idx > 0 { + builder.WriteString(separator) + } + + builder.WriteString(strconv.FormatUint(uint64(values[idx]), 10)) + } + + return builder.String() +} + +func JoinInt[T int | int8 | int16 | int32 | int64](values []T, separator string) string { + builder := strings.Builder{} + + for idx := 0; idx < len(values); idx++ { + if idx > 0 { + builder.WriteString(separator) + } + + builder.WriteString(strconv.FormatInt(int64(values[idx]), 10)) + } + + return builder.String() +} + +func WriteStrings(writer io.Writer, strings ...string) (int, error) { + totalBytesWritten := 0 + + for idx := 0; idx < len(strings); idx++ { + if bytesWritten, err := io.WriteString(writer, strings[idx]); err != nil { + return totalBytesWritten, err + } else { + totalBytesWritten += bytesWritten + } + } + + return totalBytesWritten, nil +} diff --git a/packages/go/cypher/frontend/atom.go 
b/packages/go/cypher/frontend/atom.go index 2d217325f6..5ea5e42936 100644 --- a/packages/go/cypher/frontend/atom.go +++ b/packages/go/cypher/frontend/atom.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package frontend @@ -153,11 +153,11 @@ func (s *AtomVisitor) ExitOC_Expression(ctx *parser.OC_ExpressionContext) { } func (s *AtomVisitor) EnterOC_PatternPredicate(ctx *parser.OC_PatternPredicateContext) { - s.ctx.Enter(&PatternVisitor{}) + s.ctx.Enter(NewPatternPredicateVisitor()) } func (s *AtomVisitor) ExitOC_PatternPredicate(ctx *parser.OC_PatternPredicateContext) { - s.Atom = s.ctx.Exit().(*PatternVisitor).PatternParts + s.Atom = s.ctx.Exit().(*PatternPredicateVisitor).PatternPredicate } func (s *AtomVisitor) EnterOC_Quantifier(ctx *parser.OC_QuantifierContext) { diff --git a/packages/go/cypher/frontend/context.go b/packages/go/cypher/frontend/context.go index 89e6108a44..7fa3031742 100644 --- a/packages/go/cypher/frontend/context.go +++ b/packages/go/cypher/frontend/context.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package frontend @@ -29,7 +29,7 @@ type descentEntry struct { depth int } -// Context satifies the antlr.ParseTreeListener interface needed for antlr's tree walker. +// Context satisfies the antlr.ParseTreeListener interface needed for antlr's tree walker. type Context struct { visitorStack []*descentEntry filters []Visitor diff --git a/packages/go/cypher/frontend/parse.go b/packages/go/cypher/frontend/parse.go index 7e5604ee10..7b29a287f4 100644 --- a/packages/go/cypher/frontend/parse.go +++ b/packages/go/cypher/frontend/parse.go @@ -32,25 +32,6 @@ func DefaultCypherContext() *Context { ) } -func CypherToCypher(ctx *Context, input string) (string, error) { - if query, err := ParseCypher(ctx, input); err != nil { - return "", err - } else { - var ( - output = &bytes.Buffer{} - emitter = CypherEmitter{ - StripLiterals: false, - } - ) - - if err := emitter.Write(query, output); err != nil { - return "", err - } - - return output.String(), nil - } -} - func parseCypher(ctx *Context, input string) (*model.RegularQuery, error) { var ( queryBuffer = bytes.NewBufferString(input) @@ -60,7 +41,7 @@ func parseCypher(ctx *Context, input string) (*model.RegularQuery, error) { parseTreeWalker = antlr.NewParseTreeWalker() queryVisitor = &QueryVisitor{} ) - + // Set up the lexer and parser to report errors to the context lexer.RemoveErrorListeners() lexer.AddErrorListener(ctx) diff --git a/packages/go/cypher/frontend/pattern.go b/packages/go/cypher/frontend/pattern.go index 91698f2b72..06a47710e6 100644 --- 
a/packages/go/cypher/frontend/pattern.go +++ b/packages/go/cypher/frontend/pattern.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package frontend @@ -49,11 +49,11 @@ type NodePatternVisitor struct { } func (s *NodePatternVisitor) EnterOC_Variable(ctx *parser.OC_VariableContext) { - s.ctx.Enter(&SymbolicNameOrReservedWordVisitor{}) + s.ctx.Enter(NewVariableVisitor()) } func (s *NodePatternVisitor) ExitOC_Variable(ctx *parser.OC_VariableContext) { - s.NodePattern.Binding = s.ctx.Exit().(*SymbolicNameOrReservedWordVisitor).Name + s.NodePattern.Binding = s.ctx.Exit().(*VariableVisitor).Variable } func (s *NodePatternVisitor) EnterOC_NodeLabels(ctx *parser.OC_NodeLabelsContext) { @@ -119,11 +119,11 @@ func (s *RelationshipPatternVisitor) ExitOC_RelTypeName(ctx *parser.OC_RelTypeNa } func (s *RelationshipPatternVisitor) EnterOC_Variable(ctx *parser.OC_VariableContext) { - s.ctx.Enter(&SymbolicNameOrReservedWordVisitor{}) + s.ctx.Enter(NewVariableVisitor()) } func (s *RelationshipPatternVisitor) ExitOC_Variable(ctx *parser.OC_VariableContext) { - s.RelationshipPattern.Binding = s.ctx.Exit().(*SymbolicNameOrReservedWordVisitor).Name + s.RelationshipPattern.Binding = s.ctx.Exit().(*VariableVisitor).Variable } func (s *RelationshipPatternVisitor) EnterOC_LeftArrowHead(ctx *parser.OC_LeftArrowHeadContext) { @@ -162,6 +162,40 @@ func (s *RelationshipPatternVisitor) 
ExitOC_Properties(ctx *parser.OC_Properties s.RelationshipPattern.Properties = s.ctx.Exit().(*PropertiesVisitor).Properties } +type PatternPredicateVisitor struct { + BaseVisitor + + PatternPredicate *model.PatternPredicate +} + +func NewPatternPredicateVisitor() *PatternPredicateVisitor { + return &PatternPredicateVisitor{ + PatternPredicate: model.NewPatternPredicate(), + } +} + +func (s *PatternPredicateVisitor) EnterOC_NodePattern(ctx *parser.OC_NodePatternContext) { + s.ctx.Enter(&NodePatternVisitor{ + NodePattern: &model.NodePattern{}, + }) +} + +func (s *PatternPredicateVisitor) ExitOC_NodePattern(ctx *parser.OC_NodePatternContext) { + s.PatternPredicate.AddElement(s.ctx.Exit().(*NodePatternVisitor).NodePattern) +} + +func (s *PatternPredicateVisitor) EnterOC_RelationshipPattern(ctx *parser.OC_RelationshipPatternContext) { + s.ctx.Enter(&RelationshipPatternVisitor{ + RelationshipPattern: &model.RelationshipPattern{ + Direction: graph.DirectionBoth, + }, + }) +} + +func (s *PatternPredicateVisitor) ExitOC_RelationshipPattern(ctx *parser.OC_RelationshipPatternContext) { + s.PatternPredicate.AddElement(s.ctx.Exit().(*RelationshipPatternVisitor).RelationshipPattern) +} + type PatternVisitor struct { BaseVisitor @@ -215,11 +249,11 @@ func (s *PatternVisitor) ExitOC_ShortestPathPattern(ctx *parser.OC_ShortestPathP } func (s *PatternVisitor) EnterOC_Variable(ctx *parser.OC_VariableContext) { - s.ctx.Enter(&SymbolicNameOrReservedWordVisitor{}) + s.ctx.Enter(NewVariableVisitor()) } func (s *PatternVisitor) ExitOC_Variable(ctx *parser.OC_VariableContext) { - s.currentPart.Binding = s.ctx.Exit().(*SymbolicNameOrReservedWordVisitor).Name + s.currentPart.Binding = s.ctx.Exit().(*VariableVisitor).Variable } func (s *PatternVisitor) EnterOC_NodePattern(ctx *parser.OC_NodePatternContext) { diff --git a/packages/go/cypher/frontend/query.go b/packages/go/cypher/frontend/query.go index 4e2b0b31b6..ae9ac8ce3a 100644 --- a/packages/go/cypher/frontend/query.go +++ 
b/packages/go/cypher/frontend/query.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package frontend @@ -500,7 +500,7 @@ func (s *RemoveVisitor) EnterOC_NodeLabels(ctx *parser.OC_NodeLabelsContext) { } func (s *RemoveVisitor) ExitOC_NodeLabels(ctx *parser.OC_NodeLabelsContext) { - s.currentItem.KindMatcher.Kinds = s.ctx.Exit().(*NodeLabelsVisitor).Kinds + s.currentItem.KindMatcher.(*model.KindMatcher).Kinds = s.ctx.Exit().(*NodeLabelsVisitor).Kinds } func (s *RemoveVisitor) EnterOC_Variable(ctx *parser.OC_VariableContext) { @@ -510,7 +510,7 @@ func (s *RemoveVisitor) EnterOC_Variable(ctx *parser.OC_VariableContext) { } func (s *RemoveVisitor) ExitOC_Variable(ctx *parser.OC_VariableContext) { - s.currentItem.KindMatcher.Reference = s.ctx.Exit().(*VariableVisitor).Variable + s.currentItem.KindMatcher.(*model.KindMatcher).Reference = s.ctx.Exit().(*VariableVisitor).Variable } func (s *RemoveVisitor) EnterOC_PropertyExpression(ctx *parser.OC_PropertyExpressionContext) { diff --git a/packages/go/cypher/go.mod b/packages/go/cypher/go.mod index 9ff7642908..54231f317d 100644 --- a/packages/go/cypher/go.mod +++ b/packages/go/cypher/go.mod @@ -1,34 +1,40 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 module github.com/specterops/bloodhound/cypher -go 1.20 +go 1.21 require ( github.com/antlr4-go/antlr/v4 v4.13.0 + github.com/jackc/pgtype v1.14.0 github.com/stretchr/testify v1.8.4 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgx/v4 v4.18.1 // indirect github.com/kr/pretty v0.3.1 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect + golang.org/x/crypto v0.10.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect + golang.org/x/text v0.10.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/packages/go/cypher/go.sum b/packages/go/cypher/go.sum index f4b4473d7c..7b2e94c14d 100644 --- a/packages/go/cypher/go.sum +++ b/packages/go/cypher/go.sum @@ -1,12 +1,25 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= +github.com/jackc/pgio v1.0.0 
h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= +github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/packages/go/cypher/model/copy.go b/packages/go/cypher/model/copy.go index bcef19e976..097e9fe6d8 100644 --- a/packages/go/cypher/model/copy.go +++ b/packages/go/cypher/model/copy.go @@ -1,24 +1,23 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package model import ( "fmt" - "github.com/specterops/bloodhound/dawgs/graph" ) @@ -36,7 +35,9 @@ func copySlice[T any, S []T](slice S) S { return valueCopy } -func Copy[T any](value T) T { +type CopyExtension[T any] func(value T) (T, bool) + +func Copy[T any](value T, extensions ...CopyExtension[T]) T { var empty T switch typedValue := any(value).(type) { @@ -148,7 +149,7 @@ func Copy[T any](value T) T { case *ExclusiveDisjunction: return any(typedValue.copy()).(T) - case JoiningExpression: + case expressionList: return any(typedValue.copy()).(T) case *PatternPart: @@ -175,6 +176,9 @@ func Copy[T any](value T) T { case *PatternRange: return any(typedValue.copy()).(T) + case *PatternPredicate: + return any(typedValue.copy()).(T) + case *PatternElement: return any(typedValue.copy()).(T) @@ -245,6 +249,12 @@ func Copy[T any](value T) T { return empty default: + for _, extension := range extensions { + if valueCopy, handled := extension(value); handled { + return valueCopy + } + } + panic(fmt.Sprintf("unable to copy type %T", value)) } } diff --git a/packages/go/cypher/model/copy_test.go b/packages/go/cypher/model/copy_test.go index cd5f7ff57a..7d644bec50 100644 --- a/packages/go/cypher/model/copy_test.go +++ b/packages/go/cypher/model/copy_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package model_test @@ -19,9 +19,9 @@ package model_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/cypher/model" "github.com/specterops/bloodhound/dawgs/graph" + "github.com/stretchr/testify/require" ) func validateCopy(t *testing.T, actual any) { @@ -54,7 +54,7 @@ func TestCopy(t *testing.T) { Match: &model.Match{ Optional: true, Pattern: []*model.PatternPart{{ - Binding: "p", + Binding: model.NewVariableWithSymbol("p"), ShortestPathPattern: true, AllShortestPathsPattern: true, PatternElements: []*model.PatternElement{}, @@ -136,17 +136,17 @@ func TestCopy(t *testing.T) { validateCopy(t, &model.Disjunction{}) validateCopy(t, &model.ExclusiveDisjunction{}) validateCopy(t, &model.PatternPart{ - Binding: "p", + Binding: model.NewVariableWithSymbol("p"), ShortestPathPattern: true, AllShortestPathsPattern: true, }) validateCopy(t, &model.PatternElement{}) validateCopy(t, &model.Negation{}) validateCopy(t, &model.NodePattern{ - Binding: "n", + Binding: model.NewVariableWithSymbol("n"), }) validateCopy(t, &model.RelationshipPattern{ - Binding: "r", + Binding: model.NewVariableWithSymbol("r"), Direction: graph.DirectionOutbound, }) validateCopy(t, &model.PatternRange{ @@ -157,8 +157,9 @@ func TestCopy(t *testing.T) { Ascending: true, }) validateCopy(t, []*model.PatternPart{}) - - validateCopy(t, model.JoiningExpression{}) + validateCopy(t, &model.PatternPredicate{ + PatternElements: []*model.PatternElement{{}}, + }) // External types validateCopy(t, 
[]string{}) diff --git a/packages/go/cypher/model/functions.go b/packages/go/cypher/model/functions.go new file mode 100644 index 0000000000..874d310c88 --- /dev/null +++ b/packages/go/cypher/model/functions.go @@ -0,0 +1,31 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package model + +const ( + CountFunction = "count" + DateFunction = "date" + TimeFunction = "time" + LocalTimeFunction = "localtime" + DateTimeFunction = "datetime" + LocalDateTimeFunction = "localdatetime" + DurationFunction = "duration" + IdentityFunction = "id" + ToLowerFunction = "toLower" + NodeLabelsFunction = "labels" + EdgeTypeFunction = "type" +) diff --git a/packages/go/cypher/model/model.go b/packages/go/cypher/model/model.go index 9b6445e0fb..bef5a6f156 100644 --- a/packages/go/cypher/model/model.go +++ b/packages/go/cypher/model/model.go @@ -1,26 +1,25 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package model import ( + "github.com/specterops/bloodhound/dawgs/graph" "sort" "strings" - - "github.com/specterops/bloodhound/dawgs/graph" ) type SortOrder string @@ -45,6 +44,7 @@ type Expression any type ExpressionList interface { Add(expression Expression) + AddSlice(expressions []Expression) Get(index int) Expression GetAll() []Expression Len() int @@ -53,17 +53,27 @@ type ExpressionList interface { Replace(index int, expression Expression) } -type JoiningExpression struct { +type expressionList struct { Expressions []Expression } -func (s *JoiningExpression) copy() JoiningExpression { - return JoiningExpression{ +func NewExpressionListFromSlice(slice []Expression) ExpressionList { + return &expressionList{ + Expressions: slice, + } +} + +func NewExpressionList() ExpressionList { + return &expressionList{} +} + +func (s *expressionList) copy() expressionList { + return expressionList{ Expressions: Copy(s.Expressions), } } -func (s *JoiningExpression) IndexOf(expressionToFind Expression) int { +func (s *expressionList) IndexOf(expressionToFind Expression) int { for idx, expression := range s.Expressions { if expression == expressionToFind { return idx @@ -73,11 +83,11 @@ func (s *JoiningExpression) IndexOf(expressionToFind Expression) int { return -1 } -func (s *JoiningExpression) Len() int { +func (s *expressionList) Len() int { return len(s.Expressions) } -func (s *JoiningExpression) Remove(expressionToRemove Expression) bool { +func (s *expressionList) Remove(expressionToRemove Expression) bool { for idx, expression := range s.Expressions { if expression == expressionToRemove { s.Expressions = append(s.Expressions[:idx], s.Expressions[idx+1:]...) 
@@ -88,19 +98,23 @@ func (s *JoiningExpression) Remove(expressionToRemove Expression) bool { return false } -func (s *JoiningExpression) Add(expression Expression) { +func (s *expressionList) Add(expression Expression) { s.Expressions = append(s.Expressions, expression) } -func (s *JoiningExpression) Get(index int) Expression { +func (s *expressionList) AddSlice(expressions []Expression) { + s.Expressions = append(s.Expressions, expressions...) +} + +func (s *expressionList) Get(index int) Expression { return s.Expressions[index] } -func (s *JoiningExpression) GetAll() []Expression { +func (s *expressionList) GetAll() []Expression { return s.Expressions } -func (s *JoiningExpression) Replace(index int, expression Expression) { +func (s *expressionList) Replace(index int, expression Expression) { s.Expressions[index] = expression } @@ -283,7 +297,7 @@ type SinglePartQuery struct { errorContext ReadingClauses []*ReadingClause - UpdatingClauses []*UpdatingClause + UpdatingClauses []Expression Return *Return } @@ -456,7 +470,7 @@ func (s *Remove) copy() *Remove { } type RemoveItem struct { - KindMatcher *KindMatcher + KindMatcher Expression Property *PropertyLookup } @@ -654,6 +668,10 @@ func NewLiteral(value any, null bool) *Literal { } } +func NewStringLiteral(value string) *Literal { + return NewLiteral("'"+value+"'", false) +} + func (s *Literal) copy() *Literal { return &Literal{ Value: s.Value, @@ -780,6 +798,12 @@ type Parenthetical struct { Expression Expression } +func NewParenthetical(expression Expression) *Parenthetical { + return &Parenthetical{ + Expression: expression, + } +} + func (s *Parenthetical) copy() *Parenthetical { return &Parenthetical{ Expression: Copy(s.Expression), @@ -787,12 +811,12 @@ func (s *Parenthetical) copy() *Parenthetical { } type ExclusiveDisjunction struct { - JoiningExpression + expressionList } func NewExclusiveDisjunction(expressions ...Expression) *ExclusiveDisjunction { return &ExclusiveDisjunction{ - JoiningExpression{ + 
expressionList{ Expressions: expressions, }, } @@ -804,17 +828,17 @@ func (s *ExclusiveDisjunction) copy() *ExclusiveDisjunction { } return &ExclusiveDisjunction{ - JoiningExpression: Copy(s.JoiningExpression), + expressionList: Copy(s.expressionList), } } type Disjunction struct { - JoiningExpression + expressionList } func NewDisjunction(expressions ...Expression) *Disjunction { return &Disjunction{ - JoiningExpression: JoiningExpression{ + expressionList: expressionList{ Expressions: expressions, }, } @@ -826,17 +850,17 @@ func (s *Disjunction) copy() *Disjunction { } return &Disjunction{ - JoiningExpression: Copy(s.JoiningExpression), + expressionList: Copy(s.expressionList), } } type Conjunction struct { - JoiningExpression + expressionList } func NewConjunction(expressions ...Expression) *Conjunction { return &Conjunction{ - JoiningExpression{ + expressionList{ Expressions: expressions, }, } @@ -848,7 +872,7 @@ func (s *Conjunction) copy() *Conjunction { } return &Conjunction{ - JoiningExpression: Copy(s.JoiningExpression), + expressionList: Copy(s.expressionList), } } @@ -973,7 +997,7 @@ func (s *Variable) copy() *Variable { type ProjectionItem struct { Expression Expression - Binding *Variable + Binding Expression } func NewProjectionItem() *ProjectionItem { @@ -1083,7 +1107,7 @@ func (s *Properties) copy() *Properties { // NodePattern type NodePattern struct { - Binding string + Binding Expression Kinds graph.Kinds Properties Expression } @@ -1111,7 +1135,7 @@ func (s *NodePattern) AddKind(kind graph.Kind) { // RelationshipPattern type RelationshipPattern struct { - Binding string + Binding Expression Kinds graph.Kinds Direction graph.Direction Range *PatternRange @@ -1137,7 +1161,7 @@ func (s *RelationshipPattern) AddKind(kind graph.Kind) { } type Where struct { - JoiningExpression + expressionList } func NewWhere() *Where { @@ -1150,7 +1174,7 @@ func (s *Where) copy() *Where { } return &Where{ - JoiningExpression: Copy(s.JoiningExpression), + 
expressionList: Copy(s.expressionList), } } @@ -1194,7 +1218,7 @@ type Projection struct { Order *Order Skip *Skip Limit *Limit - Items []*ProjectionItem + Items []Expression } func NewProjection(distinct bool) *Projection { @@ -1226,6 +1250,10 @@ type Return struct { Projection *Projection } +func NewReturn() *Return { + return &Return{} +} + func (s *Return) copy() *Return { if s == nil { return nil @@ -1237,7 +1265,7 @@ func (s *Return) copy() *Return { } type PatternPart struct { - Binding string + Binding Expression ShortestPathPattern bool AllShortestPathsPattern bool PatternElements []*PatternElement @@ -1311,3 +1339,27 @@ func (s *Skip) copy() *Skip { Value: Copy(s.Value), } } + +type PatternPredicate struct { + PatternElements []*PatternElement +} + +func NewPatternPredicate() *PatternPredicate { + return &PatternPredicate{} +} + +func (s *PatternPredicate) AddElement(element Expression) { + s.PatternElements = append(s.PatternElements, &PatternElement{ + Element: element, + }) +} + +func (s *PatternPredicate) copy() *PatternPredicate { + if s == nil { + return nil + } + + return &PatternPredicate{ + PatternElements: Copy(s.PatternElements), + } +} diff --git a/packages/go/cypher/model/pg/extension.go b/packages/go/cypher/model/pg/extension.go new file mode 100644 index 0000000000..de518731af --- /dev/null +++ b/packages/go/cypher/model/pg/extension.go @@ -0,0 +1,85 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import "github.com/specterops/bloodhound/cypher/model" + +func Copy[T any](value T) T { + return model.Copy(value, func(value T) (T, bool) { + var valueCopy T + + switch typedValue := any(value).(type) { + case *AnnotatedVariable: + valueCopy = any(typedValue.copy()).(T) + + case *AnnotatedKindMatcher: + valueCopy = any(typedValue.copy()).(T) + + default: + return valueCopy, false + } + + return valueCopy, true + }) +} + +func CollectPGSQLTypes(nextCursor *model.WalkCursor, expression model.Expression) bool { + switch typedExpression := expression.(type) { + case *PropertiesReference: + model.Collect(nextCursor, typedExpression.Reference) + + case *AnnotatedPropertyLookup: + model.CollectExpression(nextCursor, typedExpression.Atom) + + case *AnnotatedKindMatcher: + model.CollectExpression(nextCursor, typedExpression.Reference) + + case *Entity: + model.Collect(nextCursor, typedExpression.Binding) + + case *Subquery: + model.CollectSlice(nextCursor, typedExpression.PatternElements) + model.CollectExpression(nextCursor, typedExpression.Filter) + + case *PropertyMutation: + model.Collect(nextCursor, typedExpression.Reference) + model.Collect(nextCursor, typedExpression.Removals) + model.Collect(nextCursor, typedExpression.Additions) + + case *Delete: + model.Collect(nextCursor, typedExpression.Binding) + + case *KindMutation: + model.Collect(nextCursor, typedExpression.Variable) + model.Collect(nextCursor, typedExpression.Removals) + model.Collect(nextCursor, typedExpression.Additions) + + case *NodeKindsReference: + model.CollectExpression(nextCursor, typedExpression.Variable) + + case *EdgeKindReference: + model.CollectExpression(nextCursor, typedExpression.Variable) + + case *AnnotatedLiteral, *AnnotatedVariable, *AnnotatedParameter: + // Valid types but no descent + + default: + return false + } + + return true +} diff --git a/packages/go/cypher/model/pg/model.go b/packages/go/cypher/model/pg/model.go 
new file mode 100644 index 0000000000..0d2fcce164 --- /dev/null +++ b/packages/go/cypher/model/pg/model.go @@ -0,0 +1,404 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "errors" + "fmt" + "github.com/jackc/pgtype" + "github.com/specterops/bloodhound/cypher/model" + pgModel "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + "github.com/specterops/bloodhound/dawgs/graph" + "time" +) + +var ( + ErrNonArrayDataType = errors.New("data type is not an array type") +) + +type DataType string + +const ( + UnknownDataType DataType = "UNKNOWN" + Reference DataType = "REFERENCE" + Null DataType = "NULL" + Node DataType = "nodeComposite" + NodeArray DataType = "nodeComposite[]" + Edge DataType = "edgeComposite" + EdgeArray DataType = "edgeComposite[]" + Path DataType = "pathComposite" + Int2 DataType = "int2" + Int2Array DataType = "int2[]" + Int4 DataType = "int4" + Int4Array DataType = "int4[]" + Int8 DataType = "int8" + Int8Array DataType = "int8[]" + Float4 DataType = "float4" + Float4Array DataType = "float4[]" + Float8 DataType = "float8" + Float8Array DataType = "float8[]" + Boolean DataType = "bool" + Text DataType = "text" + TextArray DataType = "text[]" + JSONB DataType = "jsonb" + Date DataType = "date" + TimeWithTimeZone DataType = "time with time zone" + TimeWithoutTimeZone DataType = "time without time zone" + Interval DataType = "interval" + 
TimestampWithTimeZone DataType = "timestamp with time zone" + TimestampWithoutTimeZone DataType = "timestamp without time zone" +) + +func (s DataType) IsArrayType() bool { + switch s { + case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray: + return true + } + + return false +} + +func (s DataType) ArrayBaseType() (DataType, error) { + switch s { + case Int2Array: + return Int2, nil + case Int4Array: + return Int4, nil + case Int8Array: + return Int8, nil + case Float4Array: + return Float4, nil + case Float8Array: + return Float8, nil + case TextArray: + return Text, nil + default: + return UnknownDataType, ErrNonArrayDataType + } +} + +func (s DataType) String() string { + return string(s) +} + +var CompositeTypes = []DataType{Node, NodeArray, Edge, EdgeArray, Path} + +type AnnotatedKindMatcher struct { + model.KindMatcher + Type DataType +} + +func NewAnnotatedKindMatcher(kindMatcher *model.KindMatcher, dataType DataType) *AnnotatedKindMatcher { + return &AnnotatedKindMatcher{ + KindMatcher: *kindMatcher, + Type: dataType, + } +} + +func (s *AnnotatedKindMatcher) copy() *AnnotatedKindMatcher { + return &AnnotatedKindMatcher{ + KindMatcher: model.KindMatcher{ + Reference: s.Reference, + Kinds: s.Kinds, + }, + Type: s.Type, + } +} + +type AnnotatedParameter struct { + model.Parameter + Type DataType +} + +func NewAnnotatedParameter(parameter *model.Parameter, dataType DataType) *AnnotatedParameter { + return &AnnotatedParameter{ + Parameter: *parameter, + Type: dataType, + } +} + +type Entity struct { + Binding *AnnotatedVariable +} + +func NewEntity(variable *AnnotatedVariable) *Entity { + return &Entity{ + Binding: variable, + } +} + +type AnnotatedVariable struct { + model.Variable + Type DataType +} + +func NewAnnotatedVariable(variable *model.Variable, dataType DataType) *AnnotatedVariable { + return &AnnotatedVariable{ + Variable: *variable, + Type: dataType, + } +} + +func (s *AnnotatedVariable) copy() *AnnotatedVariable { + if s == 
nil { + return nil + } + + return &AnnotatedVariable{ + Variable: model.Variable{ + Symbol: s.Symbol, + }, + Type: s.Type, + } +} + +type AnnotatedPropertyLookup struct { + model.PropertyLookup + Type DataType +} + +func NewAnnotatedPropertyLookup(propertyLookup *model.PropertyLookup, dataType DataType) *AnnotatedPropertyLookup { + return &AnnotatedPropertyLookup{ + PropertyLookup: *propertyLookup, + Type: dataType, + } +} + +type AnnotatedLiteral struct { + model.Literal + Type DataType +} + +func NewAnnotatedLiteral(literal *model.Literal, dataType DataType) *AnnotatedLiteral { + return &AnnotatedLiteral{ + Literal: *literal, + Type: dataType, + } +} + +func NewStringLiteral(value string) *AnnotatedLiteral { + return NewAnnotatedLiteral(model.NewStringLiteral(value), Text) +} + +type PropertiesReference struct { + Reference *AnnotatedVariable +} + +type Subquery struct { + PatternElements []*model.PatternElement + Filter model.Expression +} + +type SubQueryAnnotation struct { + FilterExpression model.Expression +} + +type SQLTypeAnnotation struct { + Type DataType +} + +func NewSQLTypeAnnotationFromExpression(expression model.Expression) (*SQLTypeAnnotation, error) { + switch typedExpression := expression.(type) { + case *model.Parameter: + return NewSQLTypeAnnotationFromValue(typedExpression.Value) + + case *model.Literal: + return NewSQLTypeAnnotationFromLiteral(typedExpression) + + case *model.ListLiteral: + var expectedTypeAnnotation *SQLTypeAnnotation + + for _, listExpressionItem := range *typedExpression { + if listExpressionItemLiteral, isLiteral := listExpressionItem.(*model.Literal); isLiteral { + if literalTypeAnnotation, err := NewSQLTypeAnnotationFromLiteral(listExpressionItemLiteral); err != nil { + return nil, err + } else if expectedTypeAnnotation != nil && expectedTypeAnnotation.Type != literalTypeAnnotation.Type { + return nil, fmt.Errorf("list literal contains mixed types") + } else { + expectedTypeAnnotation = literalTypeAnnotation + } + } + } 
+ + return expectedTypeAnnotation, nil + + default: + return nil, fmt.Errorf("unsupported expression type %T for SQL type annotation", expression) + } +} + +func NewSQLTypeAnnotationFromLiteral(literal *model.Literal) (*SQLTypeAnnotation, error) { + if literal.Null { + return &SQLTypeAnnotation{ + Type: Null, + }, nil + } + + return NewSQLTypeAnnotationFromValue(literal.Value) +} + +func NewSQLTypeAnnotationFromValue(value any) (*SQLTypeAnnotation, error) { + if value == nil { + return &SQLTypeAnnotation{ + Type: Null, + }, nil + } + + switch typedValue := value.(type) { + case []uint16, []int16, pgtype.Int2Array: + return &SQLTypeAnnotation{ + Type: Int2Array, + }, nil + + case []uint32, []int32, []graph.ID, pgtype.Int4Array: + return &SQLTypeAnnotation{ + Type: Int4Array, + }, nil + + case []uint64, []int64, pgtype.Int8Array: + return &SQLTypeAnnotation{ + Type: Int8Array, + }, nil + + case uint16, int16: + return &SQLTypeAnnotation{ + Type: Int2, + }, nil + + case uint32, int32, graph.ID: + return &SQLTypeAnnotation{ + Type: Int4, + }, nil + + case uint, int, uint64, int64: + return &SQLTypeAnnotation{ + Type: Int8, + }, nil + + case float32: + return &SQLTypeAnnotation{ + Type: Float4, + }, nil + + case []float32: + return &SQLTypeAnnotation{ + Type: Float4Array, + }, nil + + case float64: + return &SQLTypeAnnotation{ + Type: Float8, + }, nil + + case []float64: + return &SQLTypeAnnotation{ + Type: Float8Array, + }, nil + + case bool: + return &SQLTypeAnnotation{ + Type: Boolean, + }, nil + + case string: + return &SQLTypeAnnotation{ + Type: Text, + }, nil + + case time.Time: + return &SQLTypeAnnotation{ + Type: TimestampWithTimeZone, + }, nil + + case pgtype.JSONB: + return &SQLTypeAnnotation{ + Type: JSONB, + }, nil + + case []string, pgtype.TextArray: + return &SQLTypeAnnotation{ + Type: TextArray, + }, nil + + case *model.ListLiteral: + return NewSQLTypeAnnotationFromExpression(typedValue) + + default: + return nil, fmt.Errorf("literal type %T is not 
supported", value) + } +} + +type NodeKindsReference struct { + Variable model.Expression +} + +func NewNodeKindsReference(ref *AnnotatedVariable) *NodeKindsReference { + return &NodeKindsReference{ + Variable: ref, + } +} + +type EdgeKindReference struct { + Variable model.Expression +} + +func NewEdgeKindReference(ref *AnnotatedVariable) *EdgeKindReference { + return &EdgeKindReference{ + Variable: ref, + } +} + +type Delete struct { + Binding *AnnotatedVariable + NodeDelete bool + EdgeDelete bool +} + +func NewDelete() *Delete { + return &Delete{ + NodeDelete: false, + EdgeDelete: false, + } +} + +func (s *Delete) IsMixed() bool { + return s.NodeDelete && s.EdgeDelete +} + +func (s *Delete) Table() string { + if s.NodeDelete { + return pgModel.NodeTable + } + + if s.EdgeDelete { + return pgModel.EdgeTable + } + + return "" +} + +type PropertyMutation struct { + Reference *PropertiesReference + Additions *AnnotatedParameter + Removals *AnnotatedParameter +} + +type KindMutation struct { + Variable *AnnotatedVariable + Additions *AnnotatedParameter + Removals *AnnotatedParameter +} diff --git a/packages/go/cypher/model/visitor.go b/packages/go/cypher/model/visitor.go deleted file mode 100644 index e13afcc574..0000000000 --- a/packages/go/cypher/model/visitor.go +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2023 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-// -// SPDX-License-Identifier: Apache-2.0 - -package model - -import ( - "fmt" - - "github.com/specterops/bloodhound/dawgs/graph" -) - -type Visitor func(parent, node any) error - -func walkList[T any](enter, exit Visitor, parent any, nodeList []T) error { - for i := 0; i < len(nodeList); i++ { - if err := walkNodes(enter, exit, parent, nodeList[i]); err != nil { - return err - } - } - - return nil -} - -func walkNodes(enter, exit Visitor, parent any, nodes ...any) error { - for _, node := range nodes { - if enter != nil { - if err := enter(parent, node); err != nil { - return err - } - } - - switch typedNode := node.(type) { - case ExpressionList: - for idx := 0; idx < typedNode.Len(); idx++ { - expression := typedNode.Get(idx) - - if err := walkNodes(enter, exit, node, expression); err != nil { - return err - } - } - - case *RegularQuery: - if err := walkNodes(enter, exit, node, typedNode.SingleQuery); err != nil { - return err - } - - case *SingleQuery: - if typedNode.SinglePartQuery != nil { - if err := walkNodes(enter, exit, node, typedNode.SinglePartQuery); err != nil { - return err - } - } else if typedNode.MultiPartQuery != nil { - if err := walkNodes(enter, exit, node, typedNode.MultiPartQuery); err != nil { - return err - } - } - - case *MultiPartQuery: - if err := walkList(enter, exit, typedNode, typedNode.Parts); err != nil { - return err - } - - if err := walkNodes(enter, exit, typedNode, typedNode.SinglePartQuery); err != nil { - return err - } - - case *MultiPartQueryPart: - if err := walkList(enter, exit, typedNode, typedNode.ReadingClauses); err != nil { - return err - } - - if err := walkList(enter, exit, typedNode, typedNode.UpdatingClauses); err != nil { - return err - } - - if typedNode.With != nil { - if err := walkNodes(enter, exit, typedNode, typedNode.With); err != nil { - return err - } - } - - case *Quantifier: - if err := walkNodes(enter, exit, typedNode, typedNode.Filter); err != nil { - return err - } - - case *FilterExpression: - if 
err := walkNodes(enter, exit, typedNode, typedNode.Specifier); err != nil { - return err - } - - if typedNode.Where != nil { - if err := walkNodes(enter, exit, typedNode, typedNode.Where); err != nil { - return err - } - } - - case *IDInCollection: - if err := walkNodes(enter, exit, typedNode, typedNode.Variable); err != nil { - return err - } - - if err := walkNodes(enter, exit, typedNode, typedNode.Expression); err != nil { - return err - } - - case *With: - if err := walkNodes(enter, exit, node, typedNode.Projection); err != nil { - return err - } - - if typedNode.Where != nil { - if err := walkNodes(enter, exit, node, typedNode.Where); err != nil { - return err - } - } - - case *Unwind: - if err := walkNodes(enter, exit, node, typedNode.Expression); err != nil { - return err - } - - case *ReadingClause: - if typedNode.Match != nil { - if err := walkNodes(enter, exit, node, typedNode.Match); err != nil { - return err - } - } - - if typedNode.Unwind != nil { - if err := walkNodes(enter, exit, node, typedNode.Unwind); err != nil { - return err - } - } - - case *SinglePartQuery: - if err := walkList(enter, exit, node, typedNode.ReadingClauses); err != nil { - return err - } - - if err := walkList(enter, exit, node, typedNode.UpdatingClauses); err != nil { - return err - } - - if typedNode.Return != nil { - if err := walkNodes(enter, exit, node, typedNode.Return); err != nil { - return err - } - } - - case *Remove: - if err := walkList(enter, exit, node, typedNode.Items); err != nil { - return err - } - - case *Set: - if err := walkList(enter, exit, node, typedNode.Items); err != nil { - return err - } - - case *SetItem: - if err := walkNodes(enter, exit, node, typedNode.Right, typedNode.Left); err != nil { - return err - } - - case *Negation: - if err := walkNodes(enter, exit, node, typedNode.Expression); err != nil { - return err - } - - case *PartialComparison: - if err := walkNodes(enter, exit, node, typedNode.Right); err != nil { - return err - } - - case 
*Parenthetical: - if err := walkNodes(enter, exit, node, typedNode.Expression); err != nil { - return err - } - - case *PatternElement: - if err := walkNodes(enter, exit, typedNode, typedNode.Element); err != nil { - return err - } - - case *Match: - if typedNode.Where != nil { - if err := walkNodes(enter, exit, node, typedNode.Where); err != nil { - return err - } - } - - if typedNode.Pattern != nil { - if err := walkList(enter, exit, node, typedNode.Pattern); err != nil { - return err - } - } - - case *Create: - if err := walkList(enter, exit, node, typedNode.Pattern); err != nil { - return err - } - - case *Return: - if err := walkNodes(enter, exit, node, typedNode.Projection); err != nil { - return err - } - - case *FunctionInvocation: - if err := walkList(enter, exit, node, typedNode.Arguments); err != nil { - return err - } - - case *Comparison: - if err := walkNodes(enter, exit, node, typedNode.Left); err != nil { - return err - } - - if err := walkList(enter, exit, node, typedNode.Partials); err != nil { - return err - } - - case []*PatternPart: - if err := walkList(enter, exit, parent, typedNode); err != nil { - return err - } - - case *SortItem: - if err := walkNodes(enter, exit, typedNode, typedNode.Expression); err != nil { - return err - } - - case *Order: - if err := walkList(enter, exit, typedNode, typedNode.Items); err != nil { - return err - } - - case *Projection: - if err := walkList(enter, exit, node, typedNode.Items); err != nil { - return err - } - - if typedNode.Order != nil { - if err := walkNodes(enter, exit, typedNode, typedNode.Order); err != nil { - return err - } - } - - case *ProjectionItem: - if err := walkNodes(enter, exit, node, typedNode.Expression); err != nil { - return err - } - - case *ArithmeticExpression: - if err := walkNodes(enter, exit, node, typedNode.Left); err != nil { - return err - } - - if err := walkList(enter, exit, node, typedNode.Partials); err != nil { - return err - } - - case *PartialArithmeticExpression: - if 
err := walkNodes(enter, exit, node, typedNode.Right); err != nil { - return err - } - - case *Delete: - for _, expression := range typedNode.Expressions { - if err := walkNodes(enter, exit, node, expression); err != nil { - return err - } - } - - case *KindMatcher: - if err := walkNodes(enter, exit, node, typedNode.Reference); err != nil { - return err - } - - case *RemoveItem: - if typedNode.KindMatcher != nil { - if err := walkNodes(enter, exit, node, typedNode.KindMatcher); err != nil { - return err - } - } - - case *PropertyLookup: - if err := walkNodes(enter, exit, node, typedNode.Atom); err != nil { - return err - } - - case *UpdatingClause: - if err := walkNodes(enter, exit, node, typedNode.Clause); err != nil { - return err - } - - case *NodePattern: - if err := walkNodes(enter, exit, node, typedNode.Properties); err != nil { - return err - } - - case *PatternPart: - if err := walkList(enter, exit, node, typedNode.PatternElements); err != nil { - return err - } - - case *RelationshipPattern: - if err := walkNodes(enter, exit, node, typedNode.Properties); err != nil { - return err - } - - case *Properties: - if err := walkNodes(enter, exit, node, typedNode.Parameter); err != nil { - return err - } - - case *Variable, *Literal, *Parameter, *RangeQuantifier, graph.Kinds: - // Valid model elements but no further descent required - - case nil: - default: - return fmt.Errorf("unsupported type for model traversal %T(%T)", parent, node) - } - - if exit != nil { - if err := exit(parent, node); err != nil { - return err - } - } - } - - return nil -} - -// Walk is a recursive, depth-first traversal implementation for the openCypher query model. 
-func Walk(element any, enter, exit Visitor) error { - return walkNodes(enter, exit, nil, element) -} diff --git a/packages/go/cypher/model/walk.go b/packages/go/cypher/model/walk.go new file mode 100644 index 0000000000..5297295a80 --- /dev/null +++ b/packages/go/cypher/model/walk.go @@ -0,0 +1,347 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "fmt" + "github.com/specterops/bloodhound/dawgs/graph" +) + +type WalkCursor struct { + Trunk Expression + Branches []Expression + currentBranch int +} + +func (s *WalkCursor) CurrentBranch() Expression { + return s.Branches[s.currentBranch] +} + +func (s *WalkCursor) next() (Expression, bool) { + if s.currentBranch < len(s.Branches) { + next := s.Branches[s.currentBranch] + s.currentBranch++ + + return next, true + } + + return nil, false +} + +type WalkStack struct { + stack []*WalkCursor +} + +func newStack(root Expression) *WalkStack { + return &WalkStack{ + stack: []*WalkCursor{{ + Branches: []Expression{root}, + currentBranch: 0, + }}, + } +} + +func (s *WalkStack) Push(trunk Expression) *WalkCursor { + cursor := &WalkCursor{ + Trunk: trunk, + currentBranch: 0, + } + + s.stack = append(s.stack, cursor) + return cursor +} + +func (s *WalkStack) Trunk() Expression { + if s.Empty() { + return nil + } + + return s.Peek().Trunk +} + +func (s *WalkStack) Empty() bool { + return len(s.stack) == 0 +} + +func (s 
*WalkStack) Peek() *WalkCursor { + return s.stack[len(s.stack)-1] +} + +func (s *WalkStack) PeekAt(depth int) *WalkCursor { + if index := len(s.stack) - depth - 1; depth >= 0 { + return s.stack[index] + } + + return nil +} + +func (s *WalkStack) Pop() { + s.stack = s.stack[:len(s.stack)-1] +} + +func CollectExpression(cursor *WalkCursor, expression Expression) { + if expression != nil { + cursor.Branches = append(cursor.Branches, expression) + } +} + +func CollectExpressions(cursor *WalkCursor, expressions []Expression) { + for _, expression := range expressions { + CollectExpression(cursor, expression) + } +} + +func Collect[T any](cursor *WalkCursor, expression *T) { + if expression != nil { + CollectExpression(cursor, expression) + } +} + +func CollectSlice[T any](cursor *WalkCursor, expressions []*T) { + for _, expression := range expressions { + Collect(cursor, expression) + } +} + +type Visitor interface { + Enter(stack *WalkStack, expression Expression) error + Exit(stack *WalkStack, expression Expression) error +} + +type VisitorFunc func(stack *WalkStack, branch Expression) error + +type visitor struct { + enterVisitor VisitorFunc + exitVisitor VisitorFunc +} + +func NewVisitor(enterVisitor VisitorFunc, exitVisitor VisitorFunc) Visitor { + return visitor{ + enterVisitor: enterVisitor, + exitVisitor: exitVisitor, + } +} + +func (s visitor) Enter(stack *WalkStack, expression Expression) error { + if s.enterVisitor != nil { + return s.enterVisitor(stack, expression) + } + + return nil +} + +func (s visitor) Exit(stack *WalkStack, expression Expression) error { + if s.exitVisitor != nil { + return s.exitVisitor(stack, expression) + } + + return nil +} + +type CollectorFunc func(nextCursor *WalkCursor, expression Expression) bool + +func cypherModelCollect(nextCursor *WalkCursor, expression Expression) bool { + switch typedExpr := expression.(type) { + case ExpressionList: + CollectExpressions(nextCursor, typedExpr.GetAll()) + + case *RegularQuery: + 
Collect(nextCursor, typedExpr.SingleQuery) + + case *SingleQuery: + Collect(nextCursor, typedExpr.SinglePartQuery) + Collect(nextCursor, typedExpr.MultiPartQuery) + + case *MultiPartQuery: + CollectSlice(nextCursor, typedExpr.Parts) + Collect(nextCursor, typedExpr.SinglePartQuery) + + case *MultiPartQueryPart: + CollectSlice(nextCursor, typedExpr.ReadingClauses) + CollectSlice(nextCursor, typedExpr.UpdatingClauses) + Collect(nextCursor, typedExpr.With) + + case *Quantifier: + Collect(nextCursor, typedExpr.Filter) + + case *FilterExpression: + Collect(nextCursor, typedExpr.Specifier) + Collect(nextCursor, typedExpr.Where) + + case *IDInCollection: + Collect(nextCursor, typedExpr.Variable) + CollectExpression(nextCursor, typedExpr.Expression) + + case *With: + Collect(nextCursor, typedExpr.Projection) + Collect(nextCursor, typedExpr.Where) + + case *Unwind: + CollectExpression(nextCursor, typedExpr.Expression) + Collect(nextCursor, typedExpr.Binding) + + case *ReadingClause: + Collect(nextCursor, typedExpr.Match) + Collect(nextCursor, typedExpr.Unwind) + + case *SinglePartQuery: + CollectSlice(nextCursor, typedExpr.ReadingClauses) + CollectExpressions(nextCursor, typedExpr.UpdatingClauses) + Collect(nextCursor, typedExpr.Return) + + case *Remove: + CollectSlice(nextCursor, typedExpr.Items) + + case *Set: + CollectSlice(nextCursor, typedExpr.Items) + + case *SetItem: + CollectExpression(nextCursor, typedExpr.Left) + CollectExpression(nextCursor, typedExpr.Right) + + case *Negation: + CollectExpression(nextCursor, typedExpr.Expression) + + case *PartialComparison: + CollectExpression(nextCursor, typedExpr.Right) + + case *Parenthetical: + CollectExpression(nextCursor, typedExpr.Expression) + + case *PatternElement: + CollectExpression(nextCursor, typedExpr.Element) + + case *Match: + Collect(nextCursor, typedExpr.Where) + CollectSlice(nextCursor, typedExpr.Pattern) + + case *Create: + CollectSlice(nextCursor, typedExpr.Pattern) + + case *Return: + Collect(nextCursor, 
typedExpr.Projection) + + case *FunctionInvocation: + CollectExpressions(nextCursor, typedExpr.Arguments) + + case *Comparison: + CollectExpression(nextCursor, typedExpr.Left) + CollectSlice(nextCursor, typedExpr.Partials) + + case *PatternPredicate: + CollectSlice(nextCursor, typedExpr.PatternElements) + + case *SortItem: + CollectExpression(nextCursor, typedExpr.Expression) + + case *Order: + CollectSlice(nextCursor, typedExpr.Items) + + case *Projection: + CollectExpressions(nextCursor, typedExpr.Items) + Collect(nextCursor, typedExpr.Order) + + case *ProjectionItem: + CollectExpression(nextCursor, typedExpr.Expression) + CollectExpression(nextCursor, typedExpr.Binding) + + case *ArithmeticExpression: + CollectExpression(nextCursor, typedExpr.Left) + CollectSlice(nextCursor, typedExpr.Partials) + + case *PartialArithmeticExpression: + CollectExpression(nextCursor, typedExpr.Right) + + case *Delete: + CollectExpressions(nextCursor, typedExpr.Expressions) + + case *KindMatcher: + CollectExpression(nextCursor, typedExpr.Reference) + + case *RemoveItem: + CollectExpression(nextCursor, typedExpr.KindMatcher) + Collect(nextCursor, typedExpr.Property) + + case *PropertyLookup: + CollectExpression(nextCursor, typedExpr.Atom) + + case *UpdatingClause: + CollectExpression(nextCursor, typedExpr.Clause) + + case *NodePattern: + CollectExpression(nextCursor, typedExpr.Properties) + CollectExpression(nextCursor, typedExpr.Binding) + + case *PatternPart: + CollectSlice(nextCursor, typedExpr.PatternElements) + CollectExpression(nextCursor, typedExpr.Binding) + + case *RelationshipPattern: + CollectExpression(nextCursor, typedExpr.Properties) + CollectExpression(nextCursor, typedExpr.Binding) + + case *Properties: + Collect(nextCursor, typedExpr.Parameter) + + case *Variable, *Literal, *Parameter, *RangeQuantifier, graph.Kinds: + // Valid model elements but no further descent required + + case nil: + default: + return false + } + + return true +} + +func Walk(root Expression, 
visitor Visitor, extensions ...CollectorFunc) error { + stack := newStack(root) + + for !stack.Empty() { + currentCursor := stack.Peek() + + if nextExpr, hasNext := currentCursor.next(); hasNext { + // On enter of new node + if err := visitor.Enter(stack, nextExpr); err != nil { + return err + } + + if nextCursor := stack.Push(nextExpr); !cypherModelCollect(nextCursor, nextExpr) { + collected := false + + for _, extension := range extensions { + if extension(nextCursor, nextExpr) { + collected = true + break + } + } + + if !collected { + return fmt.Errorf("unsupported type for model traversal %T", nextExpr) + } + } + } else { + stack.Pop() + + if err := visitor.Exit(stack, currentCursor.Trunk); err != nil { + return err + } + } + } + + return nil +} diff --git a/packages/go/cypher/model/visitor_test.go b/packages/go/cypher/model/walk_test.go similarity index 81% rename from packages/go/cypher/model/visitor_test.go rename to packages/go/cypher/model/walk_test.go index 63f372326d..3500ba974d 100644 --- a/packages/go/cypher/model/visitor_test.go +++ b/packages/go/cypher/model/walk_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package model_test @@ -19,17 +19,27 @@ package model_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/cypher/frontend" "github.com/specterops/bloodhound/cypher/model" "github.com/specterops/bloodhound/cypher/test" + "github.com/stretchr/testify/require" ) +type walker struct{} + +func (w walker) Enter(stack *model.WalkStack, expression model.Expression) error { + return nil +} + +func (w walker) Exit(stack *model.WalkStack, expression model.Expression) error { + return nil +} + func TestWalk(t *testing.T) { // Walk through all positive test cases to ensure that the walker can visit the involved types for _, testCase := range test.LoadFixture(t, test.PositiveTestCases).RunnableCases() { // Only bother with the string match tests - if testCase.Type == test.TestTypeStringMatch { + if testCase.Type == test.TypeStringMatch { var ( details = test.UnmarshallTestCaseDetails[test.StringMatchTest](t, testCase) parseContext = frontend.NewContext() @@ -40,7 +50,7 @@ func TestWalk(t *testing.T) { t.Fatalf("Parser errors: %s", parseErr.Error()) } - require.Nil(t, model.Walk(queryModel, nil, nil)) + require.Nil(t, model.Walk(queryModel, &walker{})) } } } diff --git a/packages/go/cypher/test/cases/positive_tests.json b/packages/go/cypher/test/cases/positive_tests.json index f42b47a4da..4286828f41 100644 --- a/packages/go/cypher/test/cases/positive_tests.json +++ b/packages/go/cypher/test/cases/positive_tests.json @@ -421,7 +421,7 @@ "name": "Eliminate duplication in lists", "type": "string_match", "details": { - "query": "match (p:Person)-[:ACTED_IN]->(m:Movie) where m.year = 1920 return collect(distinct(m.title))", + "query": "match (p:Person)-[:ACTED_IN]->(m:Movie) where m.year = 1920 return collect(distinct (m.title))", "complexity": 4.0 } }, @@ -462,7 +462,7 @@ "type": "string_match", "details": { "query": "match (b) where (b)<-[]->() return b", - "complexity": 11.0 + "complexity": 
9.0 } }, { @@ -470,7 +470,7 @@ "type": "string_match", "details": { "query": "match (b) where not ((b)<-[]->()) return b", - "complexity": 11.0 + "complexity": 9.0 } }, { @@ -958,7 +958,7 @@ "name": "Find Kerberoastable Users with most privileges", "type": "string_match", "details": { - "query": "match (u:User {hasspn: true}) optional match (u)-[:AdminTo]->(c1:Computer) optional match (u)-[:MemberOf*1..]->(:Group)-[:AdminTo]->(c2:Computer) with u, collect(c1) + collect(c2) as tempVar unwind tempVar as comps return u.name, count(distinct(comps)) order by count(distinct(comps)) desc", + "query": "match (u:User {hasspn: true}) optional match (u)-[:AdminTo]->(c1:Computer) optional match (u)-[:MemberOf*1..]->(:Group)-[:AdminTo]->(c2:Computer) with u, collect(c1) + collect(c2) as tempVar unwind tempVar as comps return u.name, count(distinct (comps)) order by count(distinct (comps)) desc", "complexity": 18.0 } }, diff --git a/packages/go/cypher/test/test.go b/packages/go/cypher/test/test.go index 674bbcfacf..e8cbe33be4 100644 --- a/packages/go/cypher/test/test.go +++ b/packages/go/cypher/test/test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package test @@ -19,24 +19,25 @@ package test import ( "embed" "encoding/json" + "github.com/specterops/bloodhound/cypher/backend" "regexp" "testing" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/cypher/frontend" + "github.com/stretchr/testify/require" ) //go:embed cases var fixtureFS embed.FS -type TestType = string +type Type = string const ( - TestTypeStringMatch TestType = "string_match" - TestTypeNegativeCase TestType = "negative_case" + TypeStringMatch Type = "string_match" + TypeNegativeCase Type = "negative_case" ) -type TestRunner interface { +type Runner interface { Run(t *testing.T, testCase Case) } @@ -98,7 +99,7 @@ type StringMatchTest struct { func (s StringMatchTest) Run(t *testing.T, testCase Case) { var ( ctx = frontend.NewContext() - result, err = frontend.CypherToCypher(ctx, s.Query) + result, err = backend.CypherToCypher(ctx, s.Query) ) if err != nil { @@ -117,7 +118,7 @@ func (s StringMatchTest) Run(t *testing.T, testCase Case) { type Case struct { Name string `json:"name"` - Type TestType `json:"type"` + Type Type `json:"type"` Targeted bool `json:"targeted"` Ignore bool `json:"ignore"` Details json.RawMessage `json:"details"` @@ -189,7 +190,7 @@ func LoadFixture(t *testing.T, filename string) Cases { return fixture } -func testRunner[T TestRunner](testCase Case) func(t *testing.T) { +func testRunner[T Runner](testCase Case) func(t *testing.T) { return func(t *testing.T) { // Run the test case if it isn't ignored if !testCase.Ignore { @@ -200,10 +201,10 @@ func testRunner[T TestRunner](testCase Case) func(t *testing.T) { func testCase(test Case) func(t *testing.T) { switch test.Type { - case TestTypeStringMatch: + case TypeStringMatch: return testRunner[StringMatchTest](test) - case TestTypeNegativeCase: + case TypeNegativeCase: return testRunner[NegativeTest](test) default: diff --git a/packages/go/dawgs/cardinality/graph.go b/packages/go/dawgs/cardinality/graph.go 
index 9fe767166a..ee4001ea1d 100644 --- a/packages/go/dawgs/cardinality/graph.go +++ b/packages/go/dawgs/cardinality/graph.go @@ -62,7 +62,7 @@ func (s KindBitmaps) Or(bitmaps KindBitmaps) { func (s KindBitmaps) AddSets(nodeSets ...graph.NodeSet) { for _, nodeSet := range nodeSets { for _, node := range nodeSet { - s.AddIDKindsPair(node.ID, node.Kinds) + s.AddIDToKinds(node.ID, node.Kinds) } } } @@ -87,17 +87,46 @@ func (s KindBitmaps) Contains(node *graph.Node) bool { return false } -func (s KindBitmaps) AddIDKindsPair(id graph.ID, kinds graph.Kinds) { +func (s KindBitmaps) AddDuplexToKind(ids Duplex[uint32], kind graph.Kind) { + kindStr := kind.String() + + if bitmap, hasBitmap := s[kindStr]; !hasBitmap { + newBitmap := NewBitmap32() + newBitmap.Or(ids) + + s[kindStr] = newBitmap + } else { + bitmap.Or(ids) + } +} + +func (s KindBitmaps) AddIDToKind(id graph.ID, kind graph.Kind) { + var ( + nodeID = id.Uint32() + kindStr = kind.String() + ) + + if bitmap, hasBitmap := s[kindStr]; !hasBitmap { + newBitmap := NewBitmap32() + newBitmap.Add(nodeID) + + s[kindStr] = newBitmap + } else { + bitmap.Add(nodeID) + } +} + +func (s KindBitmaps) AddIDToKinds(id graph.ID, kinds graph.Kinds) { nodeID := id.Uint32() - for _, nodeKind := range kinds { - nodeKindStr := nodeKind.String() + for _, kind := range kinds { + kindStr := kind.String() - if bitmap, hasBitmap := s[nodeKindStr]; !hasBitmap { + if bitmap, hasBitmap := s[kindStr]; !hasBitmap { newBitmap := NewBitmap32() newBitmap.Add(nodeID) - s[nodeKindStr] = newBitmap + s[kindStr] = newBitmap } else { bitmap.Add(nodeID) } @@ -106,7 +135,7 @@ func (s KindBitmaps) AddIDKindsPair(id graph.ID, kinds graph.Kinds) { func (s KindBitmaps) AddNodes(nodes ...*graph.Node) { for _, node := range nodes { - s.AddIDKindsPair(node.ID, node.Kinds) + s.AddIDToKinds(node.ID, node.Kinds) } } diff --git a/packages/go/dawgs/cardinality/hyperloglog_bench_test.go b/packages/go/dawgs/cardinality/hyperloglog_bench_test.go index bfb55246e7..d394d7f3d6 
100644 --- a/packages/go/dawgs/cardinality/hyperloglog_bench_test.go +++ b/packages/go/dawgs/cardinality/hyperloglog_bench_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package cardinality_test diff --git a/packages/go/dawgs/cardinality/roaring32.go b/packages/go/dawgs/cardinality/roaring32.go index 79cb4dd581..90622cd170 100644 --- a/packages/go/dawgs/cardinality/roaring32.go +++ b/packages/go/dawgs/cardinality/roaring32.go @@ -42,6 +42,13 @@ func NewBitmap32() Duplex[uint32] { } } +func NewBitmap32With(values ...uint32) Duplex[uint32] { + duplex := NewBitmap32() + duplex.Add(values...) 
+ + return duplex +} + func (s bitmap32) Clear() { s.bitmap.Clear() } diff --git a/packages/go/dawgs/dawgs.go b/packages/go/dawgs/dawgs.go index f273bcfae9..bd83e8d6a2 100644 --- a/packages/go/dawgs/dawgs.go +++ b/packages/go/dawgs/dawgs.go @@ -17,6 +17,7 @@ package dawgs import ( + "context" "errors" "github.com/specterops/bloodhound/dawgs/graph" @@ -27,7 +28,7 @@ var ( ErrDriverMissing = errors.New("driver missing") ) -type DriverConstructor func(cfg Config) (graph.Database, error) +type DriverConstructor func(ctx context.Context, cfg Config) (graph.Database, error) var availableDrivers = map[string]DriverConstructor{} @@ -40,10 +41,10 @@ type Config struct { DriverCfg any } -func Open(driverName string, config Config) (graph.Database, error) { +func Open(ctx context.Context, driverName string, config Config) (graph.Database, error) { if driverConstructor, hasDriver := availableDrivers[driverName]; !hasDriver { return nil, ErrDriverMissing } else { - return driverConstructor(config) + return driverConstructor(ctx, config) } } diff --git a/packages/go/dawgs/drivers/neo4j/batch.go b/packages/go/dawgs/drivers/neo4j/batch.go index c20999585d..14daefd594 100644 --- a/packages/go/dawgs/drivers/neo4j/batch.go +++ b/packages/go/dawgs/drivers/neo4j/batch.go @@ -42,6 +42,19 @@ type batchTransaction struct { batchWriteSize int } +func (s *batchTransaction) CreateNode(node *graph.Node) error { + _, err := s.innerTx.CreateNode(node.Properties, node.Kinds...) 
+ return err +} + +func (s *batchTransaction) CreateRelationship(relationship *graph.Relationship) error { + return s.CreateRelationshipByIDs(relationship.StartID, relationship.EndID, relationship.Kind, relationship.Properties) +} + +func (s *batchTransaction) WithGraph(graphSchema graph.Graph) graph.Batch { + return s +} + func (s *batchTransaction) Nodes() graph.NodeQuery { return NewNodeQuery(s.innerTx.ctx, s) } @@ -112,35 +125,10 @@ func (s *batchTransaction) Close() error { return s.innerTx.Close() } -func (s *batchTransaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) error { - _, err := s.innerTx.CreateNode(properties, kinds...) - return err -} - func (s *batchTransaction) UpdateNode(target *graph.Node) error { return s.innerTx.UpdateNode(target) } -func (s *batchTransaction) CreateRelationship(startNode, endNode *graph.Node, kind graph.Kind, properties *graph.Properties) error { - if startNode.ID == graph.UnregisteredNodeID { - if newStartNode, err := s.innerTx.CreateNode(startNode.Properties, startNode.Kinds...); err != nil { - return err - } else { - startNode = newStartNode - } - } - - if endNode.ID == graph.UnregisteredNodeID { - if newEndNode, err := s.innerTx.CreateNode(endNode.Properties, endNode.Kinds...); err != nil { - return err - } else { - endNode = newEndNode - } - } - - return s.CreateRelationshipByIDs(startNode.ID, endNode.ID, kind, properties) -} - func (s *batchTransaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error { nextUpdate := createRelationshipByIDs{ startID: startNodeID, @@ -176,8 +164,8 @@ func (s *batchTransaction) UpdateRelationship(relationship *graph.Relationship) return s.innerTx.UpdateRelationship(relationship) } -func (s *batchTransaction) Run(cypher string, params map[string]any) graph.Result { - return s.innerTx.Run(cypher, params) +func (s *batchTransaction) Raw(cypher string, params map[string]any) graph.Result { + return 
s.innerTx.Raw(cypher, params) } type relationshipCreateByIDBatch struct { diff --git a/packages/go/dawgs/drivers/neo4j/cypher.go b/packages/go/dawgs/drivers/neo4j/cypher.go index f52df52187..9f3386a721 100644 --- a/packages/go/dawgs/drivers/neo4j/cypher.go +++ b/packages/go/dawgs/drivers/neo4j/cypher.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -67,8 +67,8 @@ func (s relUpdateByMap) add(update graph.RelationshipUpdate) { updateKey = relUpdateKey(update) updateProperties = map[string]any{ "r": update.Relationship.Properties.Map, - "s": update.StartIdentityPropertiesMap(), - "e": update.EndIdentityPropertiesMap(), + "s": update.Start.Properties.Map, + "e": update.End.Properties.Map, } ) @@ -171,7 +171,7 @@ func cypherBuildRelationshipUpdateQueryBatch(updates []graph.RelationshipUpdate) output.WriteString("}") } - output.WriteString("]->(e) set r += p.r") + output.WriteString("]->(e) set s += p.s, e += p.e, r += p.r") if len(batch.startNodeKindsToAdd) > 0 { for _, kindToAdd := range batch.startNodeKindsToAdd { @@ -187,8 +187,7 @@ func cypherBuildRelationshipUpdateQueryBatch(updates []graph.RelationshipUpdate) } } - output.WriteString(", s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 'UTC'})") - output.WriteString(";") + output.WriteString(", s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 
'UTC'});") // Write out the query to be run queries = append(queries, output.String()) diff --git a/packages/go/dawgs/drivers/neo4j/driver.go b/packages/go/dawgs/drivers/neo4j/driver.go index 806b546695..58a811b77e 100644 --- a/packages/go/dawgs/drivers/neo4j/driver.go +++ b/packages/go/dawgs/drivers/neo4j/driver.go @@ -18,7 +18,6 @@ package neo4j import ( "context" - "fmt" "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j" @@ -87,7 +86,7 @@ func (s *driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDe return batch.Commit() } -func (s *driver) Close() error { +func (s *driver) Close(ctx context.Context) error { return s.driver.Close() } @@ -132,62 +131,15 @@ func (s *driver) WriteTransaction(ctx context.Context, txDelegate graph.Transact return s.transaction(ctx, txDelegate, session, options) } -func (s *driver) FetchSchema(ctx context.Context) (*graph.Schema, error) { - schema := graph.NewSchema() - - return schema, s.ReadTransaction(ctx, func(tx graph.Transaction) error { - if result := tx.Run("call db.indexes() yield name, uniqueness, provider, labelsOrTypes, properties;", nil); result.Error() != nil { - return result.Error() - } else { - defer result.Close() - - var ( - name string - uniqueness string - provider string - labels []string - properties []string - ) - - for result.Next() { - if err := result.Scan(&name, &uniqueness, &provider, &labels, &properties); err != nil { - return err - } - - // Need this for neo4j 4.4+ which creates a weird index by default - if len(labels) == 0 { - continue - } - - if len(labels) > 1 || len(properties) > 1 { - return fmt.Errorf("composite index types are currently not supported") - } - - label := labels[0] - property := properties[0] - - if uniqueness == "UNIQUE" { - schema.EnsureKind(graph.StringKind(label)).Constraint(property, name, parseProviderType(provider)) - } else { - schema.EnsureKind(graph.StringKind(label)).Index(property, name, parseProviderType(provider)) - } - } - - return result.Error() - 
} - }) -} - -func (s *driver) AssertSchema(ctx context.Context, schema *graph.Schema) error { - if existingSchema, err := s.FetchSchema(ctx); err != nil { - return fmt.Errorf("could not load schema: %w", err) - } else { - return assertAgainst(ctx, schema, existingSchema, s) - } +func (s *driver) AssertSchema(ctx context.Context, schema graph.Schema) error { + return assertSchema(ctx, s, schema) } func (s *driver) Run(ctx context.Context, query string, parameters map[string]any) error { return s.WriteTransaction(ctx, func(tx graph.Transaction) error { - return tx.Run(query, parameters).Error() + result := tx.Raw(query, parameters) + defer result.Close() + + return result.Error() }) } diff --git a/packages/go/dawgs/drivers/neo4j/index.go b/packages/go/dawgs/drivers/neo4j/index.go index f1ba9581d2..c743c88d5f 100644 --- a/packages/go/dawgs/drivers/neo4j/index.go +++ b/packages/go/dawgs/drivers/neo4j/index.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -19,6 +19,7 @@ package neo4j import ( "context" "fmt" + "github.com/specterops/bloodhound/log" "strings" "github.com/specterops/bloodhound/dawgs/graph" @@ -28,16 +29,80 @@ const ( nativeBTreeIndexProvider = "native-btree-1.0" nativeLuceneIndexProvider = "lucene+native-3.0" - createPropertyIndexStatement = "CALL db.createIndex($name, $labels, $properties, $provider);" - createPropertyConstraintStatement = "CALL db.createUniquePropertyConstraint($name, $labels, $properties, $provider)" + dropPropertyIndexStatement = "drop index $name;" + dropPropertyConstraintStatement = "drop constraint $name;" + createPropertyIndexStatement = "call db.createIndex($name, $labels, $properties, $provider);" + createPropertyConstraintStatement = "call db.createUniquePropertyConstraint($name, $labels, $properties, $provider);" ) +type neo4jIndex struct { + graph.Index + + kind graph.Kind +} + +type neo4jConstraint struct { + graph.Constraint + + kind graph.Kind +} + +type neo4jSchema struct { + Indexes map[string]neo4jIndex + Constraints map[string]neo4jConstraint +} + +func newNeo4jSchema() neo4jSchema { + return neo4jSchema{ + Indexes: map[string]neo4jIndex{}, + Constraints: map[string]neo4jConstraint{}, + } +} + +func toNeo4jSchema(dbSchema graph.Schema) neo4jSchema { + neo4jSchemaInst := newNeo4jSchema() + + for _, graphSchema := range dbSchema.Graphs { + for _, index := range graphSchema.NodeIndexes { + for _, kind := range graphSchema.Nodes { + indexName := strings.ToLower(kind.String()) + "_" + strings.ToLower(index.Field) + "_index" + + neo4jSchemaInst.Indexes[indexName] = neo4jIndex{ + Index: graph.Index{ + Name: indexName, + Field: index.Field, + Type: index.Type, + }, + kind: kind, + } + } + } + + for _, constraint := range graphSchema.NodeConstraints { + for _, kind := range graphSchema.Nodes { + constraintName := strings.ToLower(kind.String()) + "_" + strings.ToLower(constraint.Field) + "_constraint" + + 
neo4jSchemaInst.Constraints[constraintName] = neo4jConstraint{ + Constraint: graph.Constraint{ + Name: constraintName, + Field: constraint.Field, + Type: constraint.Type, + }, + kind: kind, + } + } + } + } + + return neo4jSchemaInst +} + func parseProviderType(provider string) graph.IndexType { switch provider { case nativeBTreeIndexProvider: return graph.BTreeIndex case nativeLuceneIndexProvider: - return graph.FullTextSearchIndex + return graph.TextSearchIndex default: return graph.UnsupportedIndex } @@ -47,197 +112,174 @@ func indexTypeProvider(indexType graph.IndexType) string { switch indexType { case graph.BTreeIndex: return nativeBTreeIndexProvider - case graph.FullTextSearchIndex: + case graph.TextSearchIndex: return nativeLuceneIndexProvider default: return "" } } -func AssertNodePropertyIndex(db graph.Database, kind graph.Kind, propertyName string, indexType graph.IndexType) error { - return db.WriteTransaction(context.Background(), func(tx graph.Transaction) error { - statement := strings.Builder{} - - if indexType != graph.BTreeIndex { - statement.WriteString("create ") - statement.WriteString(indexTypeProvider(indexType)) - statement.WriteString(" index ") - } else { - statement.WriteString("create index ") - } +func assertIndexes(ctx context.Context, db graph.Database, indexesToRemove []string, indexesToAdd map[string]neo4jIndex) error { + if err := db.WriteTransaction(ctx, func(tx graph.Transaction) error { + for _, indexToRemove := range indexesToRemove { + log.Infof("Removing index %s", indexToRemove) - statement.WriteString(strings.ToLower(kind.String())) - statement.WriteString("_") - statement.WriteString(strings.ToLower(propertyName)) - statement.WriteString("_") - statement.WriteString(indexType.String()) - statement.WriteString(" if not exists for (n:") - statement.WriteString(kind.String()) - statement.WriteString(") on (n.") - statement.WriteString(propertyName) - statement.WriteString(");") - - if result := tx.Run(statement.String(), nil); 
result.Error() != nil { - return result.Error() - } else { + result := tx.Raw(strings.Replace(dropPropertyIndexStatement, "$name", indexToRemove, 1), nil) result.Close() + + if err := result.Error(); err != nil { + return err + } } return nil - }) -} - -func formatDropSchemaCypherStmts(indexSchemas map[string]graph.IndexSchema, constraintSchemas map[string]graph.ConstraintSchema) []string { - var ( - cypherStatements []string - builder strings.Builder - ) - - for _, propertyIndexSchema := range indexSchemas { - builder.WriteString("drop index ") - builder.WriteString(propertyIndexSchema.Name) - builder.WriteString(";") - - cypherStatements = append(cypherStatements, builder.String()) - builder.Reset() + }); err != nil { + return err } - for _, propertyConstraintSchema := range constraintSchemas { - builder.WriteString("drop constraint ") - builder.WriteString(propertyConstraintSchema.Name) - builder.WriteString(";") + return db.WriteTransaction(ctx, func(tx graph.Transaction) error { + for indexName, indexToAdd := range indexesToAdd { + log.Infof("Adding index %s to labels %s on properties %s using %s", indexName, indexToAdd.kind.String(), indexToAdd.Field, indexTypeProvider(indexToAdd.Type)) - cypherStatements = append(cypherStatements, builder.String()) - builder.Reset() - } + if err := db.Run(ctx, createPropertyIndexStatement, map[string]interface{}{ + "name": indexName, + "labels": []string{indexToAdd.kind.String()}, + "properties": []string{indexToAdd.Field}, + "provider": indexTypeProvider(indexToAdd.Type), + }); err != nil { + return err + } + } - return cypherStatements + return nil + }) } -func assertAgainst(ctx context.Context, requiredSchema, existingSchema *graph.Schema, db graph.Database) error { - var ( - createConstraints = func(requiredKindSchema *graph.KindSchema, constraints map[string]graph.ConstraintSchema) error { - for property, constraintToCreate := range constraints { - if err := db.Run(ctx, createPropertyConstraintStatement, 
map[string]interface{}{ - "name": constraintToCreate.Name, - "labels": []string{requiredKindSchema.Name()}, - "properties": []string{property}, - "provider": indexTypeProvider(constraintToCreate.IndexType), - }); err != nil { - return err - } - } +func assertConstraints(ctx context.Context, db graph.Database, constraintsToRemove []string, constraintsToAdd map[string]neo4jConstraint) error { + for _, constraintToRemove := range constraintsToRemove { + if err := db.Run(ctx, strings.Replace(dropPropertyConstraintStatement, "$name", constraintToRemove, 1), nil); err != nil { + return err + } + } - return nil + for constraintName, constraintToAdd := range constraintsToAdd { + if err := db.Run(ctx, createPropertyConstraintStatement, map[string]interface{}{ + "name": constraintName, + "labels": []string{constraintToAdd.kind.String()}, + "properties": []string{constraintToAdd.Field}, + "provider": indexTypeProvider(constraintToAdd.Type), + }); err != nil { + return err } + } - createIndices = func(requiredKindSchema *graph.KindSchema, indices map[string]graph.IndexSchema) error { - for property, indexToCreate := range indices { - if err := db.Run(ctx, createPropertyIndexStatement, map[string]interface{}{ - "name": indexToCreate.Name, - "labels": []string{requiredKindSchema.Name()}, - "properties": []string{property}, - "provider": indexTypeProvider(indexToCreate.IndexType), - }); err != nil { - return err - } - } + return nil +} - return nil - } - ) +func fetchPresentSchema(ctx context.Context, db graph.Database) (neo4jSchema, error) { + presentSchema := newNeo4jSchema() - for _, kindSchema := range existingSchema.Kinds { - if requiredKindSchema, hasMatchingDefinition := requiredSchema.Kinds[kindSchema.Kind]; !hasMatchingDefinition { - // Remove all schematic definitions for the kind since there's no matching requirement - for _, dropStmt := range formatDropSchemaCypherStmts(kindSchema.PropertyIndices, kindSchema.PropertyConstraints) { - if err := db.Run(ctx, dropStmt, 
nil); err != nil { - return err - } - } + return presentSchema, db.ReadTransaction(ctx, func(tx graph.Transaction) error { + if result := tx.Raw("call db.indexes() yield name, uniqueness, provider, labelsOrTypes, properties;", nil); result.Error() != nil { + return result.Error() } else { + defer result.Close() + var ( - indicesToAdd = map[string]graph.IndexSchema{} - indicesToRemove = map[string]graph.IndexSchema{} - constraintsToAdd = map[string]graph.ConstraintSchema{} - constraintsToRemove = map[string]graph.ConstraintSchema{} + name string + uniqueness string + provider string + labels []string + properties []string ) - // Match existing schematics to the definitions first - for property, indexSchema := range kindSchema.PropertyIndices { - if requiredIndexSchema, hasMatchingDefinition := requiredKindSchema.PropertyIndices[property]; !hasMatchingDefinition { - // If there's no matching index for this property defined, remove it from the database - indicesToRemove[property] = indexSchema - } else if !indexSchema.Equals(requiredIndexSchema) { - // The existing index does not match the requirement properties, recreate it - indicesToRemove[property] = indexSchema - indicesToAdd[property] = requiredIndexSchema + for result.Next() { + if err := result.Scan(&name, &uniqueness, &provider, &labels, &properties); err != nil { + return err } - } - // Sweep required schematics to ensure that missing entries are created - for property, requiredIndexSchema := range requiredKindSchema.PropertyIndices { - if _, hasMatchingDefinition := kindSchema.PropertyIndices[property]; !hasMatchingDefinition { - // If there's no matching index for this property defined, create it - indicesToAdd[property] = requiredIndexSchema + // Need this for neo4j 4.4+ which creates a weird index by default + if len(labels) == 0 { + continue } - } - for property, constraintSchema := range kindSchema.PropertyConstraints { - if requiredConstraintSchema, hasMatchingDefinition := 
requiredKindSchema.PropertyConstraints[property]; !hasMatchingDefinition { - // If there's no matching constraint for this property defined, remove it from the database - constraintsToRemove[property] = constraintSchema - } else if !constraintSchema.Equals(requiredConstraintSchema) { - // The existing constraint does not match the requirement properties, recreate it - constraintsToRemove[property] = constraintSchema - constraintsToAdd[property] = requiredConstraintSchema + if len(labels) > 1 || len(properties) > 1 { + return fmt.Errorf("composite index types are currently not supported") } - } - for property, constraintSchema := range requiredKindSchema.PropertyConstraints { - if _, hasMatchingDefinition := kindSchema.PropertyConstraints[property]; !hasMatchingDefinition { - // If there's no matching constraint for this property defined, create it - constraintsToAdd[property] = constraintSchema + if uniqueness == "UNIQUE" { + presentSchema.Constraints[name] = neo4jConstraint{ + Constraint: graph.Constraint{ + Name: name, + Field: properties[0], + Type: parseProviderType(provider), + }, + kind: graph.StringKind(labels[0]), + } + } else { + presentSchema.Indexes[name] = neo4jIndex{ + Index: graph.Index{ + Name: name, + Field: properties[0], + Type: parseProviderType(provider), + }, + kind: graph.StringKind(labels[0]), + } } } - // Drop all indices and constraints first - for _, dropStmt := range formatDropSchemaCypherStmts(indicesToRemove, constraintsToRemove) { - if err := db.Run(ctx, dropStmt, nil); err != nil { - return err - } - } + return result.Error() + } + }) +} - if err := createIndices(requiredKindSchema, indicesToAdd); err != nil { - return err +func assertSchema(ctx context.Context, db graph.Database, required graph.Schema) error { + requiredNeo4jSchema := toNeo4jSchema(required) + + if presentNeo4jSchema, err := fetchPresentSchema(ctx, db); err != nil { + return err + } else { + var ( + indexesToRemove []string + constraintsToRemove []string + 
indexesToAdd = map[string]neo4jIndex{} + constraintsToAdd = map[string]neo4jConstraint{} + ) + + for presentIndexName := range presentNeo4jSchema.Indexes { + if _, hasMatchingDefinition := requiredNeo4jSchema.Indexes[presentIndexName]; !hasMatchingDefinition { + indexesToRemove = append(indexesToRemove, presentIndexName) } + } - if err := createConstraints(requiredKindSchema, constraintsToAdd); err != nil { - return err + for presentConstraintName := range presentNeo4jSchema.Constraints { + if _, hasMatchingDefinition := requiredNeo4jSchema.Constraints[presentConstraintName]; !hasMatchingDefinition { + constraintsToRemove = append(constraintsToRemove, presentConstraintName) } } - } - for _, requiredKindSchema := range requiredSchema.Kinds { - if _, hasMatchingDefinition := existingSchema.Kinds[requiredKindSchema.Kind]; !hasMatchingDefinition { - // There's no matching definitions for indices or constraints for the required kind. Create them. - if err := createIndices(requiredKindSchema, requiredKindSchema.PropertyIndices); err != nil { - return err + for requiredIndexName, requiredIndex := range requiredNeo4jSchema.Indexes { + if presentIndex, hasMatchingDefinition := presentNeo4jSchema.Indexes[requiredIndexName]; !hasMatchingDefinition { + indexesToAdd[requiredIndexName] = requiredIndex + } else if requiredIndex.Type != presentIndex.Type { + indexesToRemove = append(indexesToRemove, requiredIndexName) + indexesToAdd[requiredIndexName] = requiredIndex } + } - if err := createConstraints(requiredKindSchema, requiredKindSchema.PropertyConstraints); err != nil { - return err + for requiredConstraintName, requiredConstraint := range requiredNeo4jSchema.Constraints { + if presentConstraint, hasMatchingDefinition := presentNeo4jSchema.Constraints[requiredConstraintName]; !hasMatchingDefinition { + constraintsToAdd[requiredConstraintName] = requiredConstraint + } else if requiredConstraint.Type != presentConstraint.Type { + constraintsToRemove = 
append(constraintsToRemove, requiredConstraintName) + constraintsToAdd[requiredConstraintName] = requiredConstraint } } - } - return nil -} + if err := assertConstraints(ctx, db, constraintsToRemove, constraintsToAdd); err != nil { + return err + } -func AssertSchema(ctx context.Context, db graph.Database, desiredSchema *graph.Schema) error { - if existingSchema, err := db.FetchSchema(ctx); err != nil { - return fmt.Errorf("could not load schema: %w", err) - } else { - return assertAgainst(ctx, desiredSchema, existingSchema, db) + return assertIndexes(ctx, db, indexesToRemove, indexesToAdd) } } diff --git a/packages/go/dawgs/drivers/neo4j/mapper.go b/packages/go/dawgs/drivers/neo4j/mapper.go new file mode 100644 index 0000000000..23d40f1021 --- /dev/null +++ b/packages/go/dawgs/drivers/neo4j/mapper.go @@ -0,0 +1,98 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package neo4j + +import ( + "fmt" + "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" + "github.com/specterops/bloodhound/dawgs/graph" + "time" +) + +func AsTime(value any) (time.Time, error) { + switch typedValue := value.(type) { + case dbtype.Time: + return typedValue.Time(), nil + + case dbtype.LocalTime: + return typedValue.Time(), nil + + case dbtype.Date: + return typedValue.Time(), nil + + case dbtype.LocalDateTime: + return typedValue.Time(), nil + + default: + return graph.AsTime(value) + } +} + +func mapValue(rawValue, target any) (bool, error) { + switch typedTarget := target.(type) { + case *time.Time: + if value, err := AsTime(rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *dbtype.Relationship: + if value, typeOK := rawValue.(dbtype.Relationship); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to *dbtype.Relationship", rawValue) + } else { + *typedTarget = value + } + + case *graph.Relationship: + if value, typeOK := rawValue.(dbtype.Relationship); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to *dbtype.Relationship", rawValue) + } else { + *typedTarget = *newRelationship(value) + } + + case *dbtype.Node: + if value, typeOK := rawValue.(dbtype.Node); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to *dbtype.Node", rawValue) + } else { + *typedTarget = value + } + + case *graph.Node: + if value, typeOK := rawValue.(dbtype.Node); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to *dbtype.Node", rawValue) + } else { + *typedTarget = *newNode(value) + } + + case *graph.Path: + if value, typeOK := rawValue.(dbtype.Path); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to *dbtype.Path", rawValue) + } else { + *typedTarget = newPath(value) + } + + default: + return false, nil + } + + return true, nil +} + +func 
NewValueMapper(values []any) graph.ValueMapper { + return graph.NewValueMapper(values, mapValue) +} diff --git a/packages/go/dawgs/drivers/neo4j/neo4j.go b/packages/go/dawgs/drivers/neo4j/neo4j.go index 3c1349ffad..b1b2cc60a9 100644 --- a/packages/go/dawgs/drivers/neo4j/neo4j.go +++ b/packages/go/dawgs/drivers/neo4j/neo4j.go @@ -17,6 +17,7 @@ package neo4j import ( + "context" "fmt" "math" "net/url" @@ -33,7 +34,7 @@ const ( defaultNeo4jTransactionTimeout = math.MinInt ) -func newNeo4jDB(cfg dawgs.Config) (graph.Database, error) { +func newNeo4jDB(ctx context.Context, cfg dawgs.Config) (graph.Database, error) { if connectionURLStr, typeOK := cfg.DriverCfg.(string); !typeOK { return nil, fmt.Errorf("expected string for configuration type but got %T", cfg.DriverCfg) } else if connectionURL, err := url.Parse(connectionURLStr); err != nil { @@ -61,7 +62,7 @@ func newNeo4jDB(cfg dawgs.Config) (graph.Database, error) { } func init() { - dawgs.Register(DriverName, func(cfg dawgs.Config) (graph.Database, error) { - return newNeo4jDB(cfg) + dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (graph.Database, error) { + return newNeo4jDB(ctx, cfg) }) } diff --git a/packages/go/dawgs/drivers/neo4j/node.go b/packages/go/dawgs/drivers/neo4j/node.go index 5aac33b2e3..1cb8c464bc 100644 --- a/packages/go/dawgs/drivers/neo4j/node.go +++ b/packages/go/dawgs/drivers/neo4j/node.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -64,10 +64,10 @@ func NewNodeQuery(ctx context.Context, tx innerTransaction) graph.NodeQuery { } func (s *NodeQuery) run(statement string, parameters map[string]any) graph.Result { - return s.tx.Run(statement, parameters) + return s.tx.Raw(statement, parameters) } -func (s *NodeQuery) Execute(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { +func (s *NodeQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { for _, criteria := range finalCriteria { s.queryBuilder.Apply(criteria) } @@ -131,7 +131,7 @@ func (s *NodeQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.NodeQ func (s *NodeQuery) Count() (int64, error) { var count int64 - return count, s.Execute(func(results graph.Result) error { + return count, s.Query(func(results graph.Result) error { if !results.Next() { return graph.ErrNoResultsFound } @@ -171,7 +171,7 @@ func (s *NodeQuery) Update(properties *graph.Properties) error { func (s *NodeQuery) First() (*graph.Node, error) { var node graph.Node - return &node, s.Execute(func(results graph.Result) error { + return &node, s.Query(func(results graph.Result) error { if !results.Next() { return graph.ErrNoResultsFound } @@ -183,7 +183,7 @@ func (s *NodeQuery) First() (*graph.Node, error) { } func (s *NodeQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Node]) error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (*graph.Node, error) { var node graph.Node return &node, scanner.Scan(&node) @@ -197,7 +197,7 @@ func (s *NodeQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Node]) error) } func (s *NodeQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) 
error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.ID, error) { var nodeID graph.ID return nodeID, scanner.Scan(&nodeID) @@ -211,7 +211,7 @@ func (s *NodeQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) } func (s *NodeQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.KindsResult]) error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.KindsResult, error) { var ( nodeID graph.ID diff --git a/packages/go/dawgs/drivers/neo4j/relationship.go b/packages/go/dawgs/drivers/neo4j/relationship.go index fccac1cc8f..f30bb87ed8 100644 --- a/packages/go/dawgs/drivers/neo4j/relationship.go +++ b/packages/go/dawgs/drivers/neo4j/relationship.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -29,14 +29,14 @@ import ( func directionToReturnCriteria(direction graph.Direction) (graph.Criteria, error) { switch direction { case graph.DirectionInbound: - // Return the relationship and the end node + // Select the relationship and the end node return query.Returning( query.Relationship(), query.End(), ), nil case graph.DirectionOutbound: - // Return the relationship and the start node + // Select the relationship and the start node return query.Returning( query.Relationship(), query.Start(), @@ -78,10 +78,10 @@ func NewRelationshipQuery(ctx context.Context, tx innerTransaction) graph.Relati } func (s *RelationshipQuery) run(statement string, parameters map[string]any) graph.Result { - return s.tx.Run(statement, parameters) + return s.tx.Raw(statement, parameters) } -func (s *RelationshipQuery) Execute(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { +func (s *RelationshipQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error { for _, criteria := range finalCriteria { s.queryBuilder.Apply(criteria) } @@ -169,7 +169,7 @@ func (s *RelationshipQuery) Filterf(criteriaDelegate graph.CriteriaProvider) gra func (s *RelationshipQuery) Count() (int64, error) { var count int64 - return count, s.Execute(func(results graph.Result) error { + return count, s.Query(func(results graph.Result) error { if !results.Next() { return graph.ErrNoResultsFound } @@ -209,7 +209,7 @@ func (s *RelationshipQuery) FetchAllShortestPaths(delegate func(cursor graph.Cur } func (s *RelationshipQuery) FetchTriples(delegate func(cursor graph.Cursor[graph.RelationshipTripleResult]) error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.RelationshipTripleResult, error) { var ( startID graph.ID @@ -227,15 +227,15 @@ 
func (s *RelationshipQuery) FetchTriples(delegate func(cursor graph.Cursor[graph defer cursor.Close() return delegate(cursor) - }, query.Returning( - query.Distinct(query.StartID()), + }, query.ReturningDistinct( + query.StartID(), query.RelationshipID(), query.EndID(), )) } func (s *RelationshipQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.RelationshipKindsResult]) error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.RelationshipKindsResult, error) { var ( startID graph.ID @@ -268,7 +268,7 @@ func (s *RelationshipQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.R func (s *RelationshipQuery) First() (*graph.Relationship, error) { var relationship graph.Relationship - return &relationship, s.Execute(func(results graph.Result) error { + return &relationship, s.Query(func(results graph.Result) error { if !results.Next() { return graph.ErrNoResultsFound } @@ -280,7 +280,7 @@ func (s *RelationshipQuery) First() (*graph.Relationship, error) { } func (s *RelationshipQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Relationship]) error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (*graph.Relationship, error) { var relationship graph.Relationship return &relationship, scanner.Scan(&relationship) @@ -297,7 +297,7 @@ func (s *RelationshipQuery) FetchDirection(direction graph.Direction, delegate f if returnCriteria, err := directionToReturnCriteria(direction); err != nil { return err } else { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.DirectionalResult, error) { var ( relationship graph.Relationship @@ -322,7 +322,7 @@ 
func (s *RelationshipQuery) FetchDirection(direction graph.Direction, delegate f } func (s *RelationshipQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error { - return s.Execute(func(result graph.Result) error { + return s.Query(func(result graph.Result) error { cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.ID, error) { var relationshipID graph.ID return relationshipID, scanner.Scan(&relationshipID) diff --git a/packages/go/dawgs/drivers/neo4j/result.go b/packages/go/dawgs/drivers/neo4j/result.go index 00b9affee8..76c4b1de94 100644 --- a/packages/go/dawgs/drivers/neo4j/result.go +++ b/packages/go/dawgs/drivers/neo4j/result.go @@ -1,468 +1,26 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package neo4j import ( - "fmt" - "time" - "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" "github.com/specterops/bloodhound/dawgs/graph" ) -func asUint8(value any) (uint8, error) { - switch typedValue := value.(type) { - case uint8: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to uint8", value) - } -} - -func asUint16(value any) (uint16, error) { - switch typedValue := value.(type) { - case uint8: - return uint16(typedValue), nil - case uint16: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to uint16", value) - } -} - -func asUint32(value any) (uint32, error) { - switch typedValue := value.(type) { - case uint8: - return uint32(typedValue), nil - case uint16: - return uint32(typedValue), nil - case uint32: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to uint32", value) - } -} - -func asUint64(value any) (uint64, error) { - switch typedValue := value.(type) { - case uint: - return uint64(typedValue), nil - case uint8: - return uint64(typedValue), nil - case uint16: - return uint64(typedValue), nil - case uint32: - return uint64(typedValue), nil - case uint64: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to uint64", value) - } -} - -func asUint(value any) (uint, error) { - switch typedValue := value.(type) { - case uint: - return typedValue, nil - case uint8: - return uint(typedValue), nil - case uint16: - return uint(typedValue), nil - case uint32: - return uint(typedValue), nil - case uint64: - return uint(typedValue), nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to uint", value) - } -} - -func asInt8(value any) (int8, error) { - switch typedValue := value.(type) { - case int8: - return typedValue, nil - default: - return 0, 
fmt.Errorf("unexecpted type %T will not negotiate to int8", value) - } -} - -func asInt16(value any) (int16, error) { - switch typedValue := value.(type) { - case int8: - return int16(typedValue), nil - case int16: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to int16", value) - } -} - -func asInt32(value any) (int32, error) { - switch typedValue := value.(type) { - case int8: - return int32(typedValue), nil - case int16: - return int32(typedValue), nil - case int32: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to int32", value) - } -} - -func asInt64(value any) (int64, error) { - switch typedValue := value.(type) { - case graph.ID: - return int64(typedValue), nil - case int: - return int64(typedValue), nil - case int8: - return int64(typedValue), nil - case int16: - return int64(typedValue), nil - case int32: - return int64(typedValue), nil - case int64: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to int64", value) - } -} - -func asInt(value any) (int, error) { - switch typedValue := value.(type) { - case int: - return typedValue, nil - case int8: - return int(typedValue), nil - case int16: - return int(typedValue), nil - case int32: - return int(typedValue), nil - case int64: - return int(typedValue), nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to int", value) - } -} - -func asFloat32(value any) (float32, error) { - switch typedValue := value.(type) { - case float32: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to int64", value) - } -} - -func asFloat64(value any) (float64, error) { - switch typedValue := value.(type) { - case float32: - return float64(typedValue), nil - case float64: - return typedValue, nil - default: - return 0, fmt.Errorf("unexecpted type %T will not negotiate to int64", value) - } -} - -func 
asTime(value any) (time.Time, error) { - switch typedValue := value.(type) { - case string: - if parsedTime, err := time.Parse(time.RFC3339Nano, typedValue); err != nil { - return time.Time{}, err - } else { - return parsedTime, nil - } - - case dbtype.Time: - return typedValue.Time(), nil - - case dbtype.LocalTime: - return typedValue.Time(), nil - - case dbtype.Date: - return typedValue.Time(), nil - - case dbtype.LocalDateTime: - return typedValue.Time(), nil - - case float64: - return time.Unix(int64(typedValue), 0), nil - - case int64: - return time.Unix(typedValue, 0), nil - - case time.Time: - return typedValue, nil - - default: - return time.Time{}, fmt.Errorf("unexecpted type %T will not negotiate to time.Time", value) - } -} - -func mapValue(target, rawValue any) error { - switch typedTarget := target.(type) { - case *uint: - if value, err := asUint(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *uint8: - if value, err := asUint8(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *uint16: - if value, err := asUint16(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *uint32: - if value, err := asUint32(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *uint64: - if value, err := asUint64(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *int: - if value, err := asInt(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *int8: - if value, err := asInt8(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *int16: - if value, err := asInt16(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *int32: - if value, err := asInt32(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *int64: - if value, err := asInt64(rawValue); err != nil { - return err - } else { - 
*typedTarget = value - } - - case *graph.ID: - if value, err := asInt64(rawValue); err != nil { - return err - } else { - *typedTarget = graph.ID(value) - } - - case *float32: - if value, err := asFloat32(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *float64: - if value, err := asFloat64(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case *bool: - if value, typeOK := rawValue.(bool); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to bool", rawValue) - } else { - *typedTarget = value - } - - case *graph.Kind: - if strValue, typeOK := rawValue.(string); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to string", rawValue) - } else { - *typedTarget = graph.StringKind(strValue) - } - - case *string: - if value, typeOK := rawValue.(string); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to string", rawValue) - } else { - *typedTarget = value - } - - case *[]graph.Kind: - if rawValues, typeOK := rawValue.([]any); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to []any", rawValue) - } else if kindValues, err := anySliceToKinds(rawValues); err != nil { - return err - } else { - *typedTarget = kindValues - } - - case *graph.Kinds: - if rawValues, typeOK := rawValue.([]any); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to []any", rawValue) - } else if kindValues, err := anySliceToKinds(rawValues); err != nil { - return err - } else { - *typedTarget = kindValues - } - - case *[]string: - if rawValues, typeOK := rawValue.([]any); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to []any", rawValue) - } else if stringValues, err := anySliceToStringSlice(rawValues); err != nil { - return err - } else { - *typedTarget = stringValues - } - - case *time.Time: - if value, err := asTime(rawValue); err != nil { - return err - } else { - *typedTarget = value - } - - case 
*dbtype.Relationship: - if value, typeOK := rawValue.(dbtype.Relationship); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to *dbtype.Relationship", rawValue) - } else { - *typedTarget = value - } - - case *graph.Relationship: - if value, typeOK := rawValue.(dbtype.Relationship); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to *dbtype.Relationship", rawValue) - } else { - *typedTarget = *newRelationship(value) - } - - case *dbtype.Node: - if value, typeOK := rawValue.(dbtype.Node); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to *dbtype.Node", rawValue) - } else { - *typedTarget = value - } - - case *graph.Node: - if value, typeOK := rawValue.(dbtype.Node); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to *dbtype.Node", rawValue) - } else { - *typedTarget = *newNode(value) - } - - case *graph.Path: - if value, typeOK := rawValue.(dbtype.Path); !typeOK { - return fmt.Errorf("unexecpted type %T will not negotiate to *dbtype.Path", rawValue) - } else { - *typedTarget = newPath(value) - } - - default: - return fmt.Errorf("unsupported scan type %T", target) - } - - return nil -} - -type ValueMapper struct { - values []any - idx int -} - -func NewValueMapper(values []any) *ValueMapper { - return &ValueMapper{ - values: values, - idx: 0, - } -} - -func (s *ValueMapper) Next() (any, error) { - if s.idx >= len(s.values) { - return nil, fmt.Errorf("attempting to get more values than returned - saw %d but wanted %d", len(s.values), s.idx+1) - } - - nextValue := s.values[s.idx] - s.idx++ - - return nextValue, nil -} - -func (s *ValueMapper) Map(target any) error { - if rawValue, err := s.Next(); err != nil { - return err - } else { - return mapValue(target, rawValue) - } -} - -func (s *ValueMapper) MapOptions(targets ...any) (any, error) { - if rawValue, err := s.Next(); err != nil { - return nil, err - } else { - for _, target := range targets { - if mapValue(target, rawValue) == nil 
{ - return target, nil - } - } - - return nil, fmt.Errorf("no matching target given for type: %T", rawValue) - } -} - -func (s *ValueMapper) Scan(targets ...any) error { - for idx, mapValue := range targets { - if err := s.Map(mapValue); err != nil { - return err - } else { - targets[idx] = mapValue - } - } - - return nil -} - type internalResult struct { query string err error @@ -477,37 +35,18 @@ func NewResult(query string, err error, driverResult neo4j.Result) graph.Result } } -func anySliceToStringSlice(rawValues []any) ([]string, error) { - strings := make([]string, len(rawValues)) - - for idx, rawValue := range rawValues { - switch typedValue := rawValue.(type) { - case string: - strings[idx] = typedValue - default: - return nil, fmt.Errorf("unexpected type %T will not negotiate to string", rawValue) - } - } - - return strings, nil +func (s *internalResult) Values() (graph.ValueMapper, error) { + return NewValueMapper(s.driverResult.Record().Values), nil } -func anySliceToKinds(rawValues []any) (graph.Kinds, error) { - if stringValues, err := anySliceToStringSlice(rawValues); err != nil { - return nil, err +func (s *internalResult) Scan(targets ...any) error { + if values, err := s.Values(); err != nil { + return err } else { - return graph.StringsToKinds(stringValues), nil + return values.Scan(targets...) } } -func (s *internalResult) Values() graph.ValueMapper { - return NewValueMapper(s.driverResult.Record().Values) -} - -func (s *internalResult) Scan(targets ...any) error { - return s.Values().Scan(targets...) -} - func (s *internalResult) Next() bool { return s.driverResult.Next() } @@ -525,6 +64,8 @@ func (s *internalResult) Error() error { } func (s *internalResult) Close() { - // Ignore the results of this call. This is called only as a best-effort attempt at a close - s.driverResult.Consume() + if s.driverResult != nil { + // Ignore the results of this call. 
This is called only as a best-effort attempt at a close + s.driverResult.Consume() + } } diff --git a/packages/go/dawgs/drivers/neo4j/result_internal_test.go b/packages/go/dawgs/drivers/neo4j/result_internal_test.go index f7f91015e5..68b32cc397 100644 --- a/packages/go/dawgs/drivers/neo4j/result_internal_test.go +++ b/packages/go/dawgs/drivers/neo4j/result_internal_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -21,14 +21,17 @@ import ( "time" "github.com/neo4j/neo4j-go-driver/v5/neo4j/dbtype" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/dawgs/graph" + "github.com/stretchr/testify/require" ) func mapTestCase[T, V any](t *testing.T, source T, expected V) { - var value V + var ( + mapper = graph.NewValueMapper([]any{source}, mapValue) + value V + ) - require.Nil(t, mapValue(&value, source)) + require.Nil(t, mapper.Map(&value)) require.Equalf(t, expected, value, "Mapping case for type %T to %T failed. Value is: %v", source, &value, value) } diff --git a/packages/go/dawgs/drivers/neo4j/result_test.go b/packages/go/dawgs/drivers/neo4j/result_test.go index e8d91b4ccd..3a97b93559 100644 --- a/packages/go/dawgs/drivers/neo4j/result_test.go +++ b/packages/go/dawgs/drivers/neo4j/result_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. 
-// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package neo4j_test @@ -19,8 +19,8 @@ package neo4j_test import ( "testing" - "github.com/stretchr/testify/require" "github.com/specterops/bloodhound/dawgs/drivers/neo4j" + "github.com/stretchr/testify/require" ) func TestValueMapper_MapOptions(t *testing.T) { @@ -32,16 +32,15 @@ func TestValueMapper_MapOptions(t *testing.T) { stringOption string mappedPointer, err = mapper.MapOptions(&floatOption, &intOption) - _, isIntPointer = mappedPointer.(*int) + _, isFloatPointer = mappedPointer.(*float64) ) require.Nil(t, err) - require.True(t, isIntPointer) - require.Equal(t, 1, intOption) + require.True(t, isFloatPointer) + require.Equal(t, float64(1), floatOption) mappedPointer, err = mapper.MapOptions(&stringOption) require.Nil(t, mappedPointer) - require.ErrorContains(t, err, "no matching target given") - + require.ErrorContains(t, err, "no matching target given for type: int") } diff --git a/packages/go/dawgs/drivers/neo4j/transaction.go b/packages/go/dawgs/drivers/neo4j/transaction.go index 1506dc51f0..394c7dcbe2 100644 --- a/packages/go/dawgs/drivers/neo4j/transaction.go +++ b/packages/go/dawgs/drivers/neo4j/transaction.go @@ -20,7 +20,8 @@ import ( "context" "encoding/json" "fmt" - "log" + "github.com/specterops/bloodhound/dawgs/drivers" + "github.com/specterops/bloodhound/log" "sort" "strings" @@ -41,7 +42,7 @@ const ( ) type innerTransaction interface { - Run(cypher string, 
params map[string]any) graph.Result + Raw(cypher string, params map[string]any) graph.Result } type neo4jTransaction struct { @@ -55,6 +56,19 @@ type neo4jTransaction struct { traversalMemoryLimit size.Size } +func (s *neo4jTransaction) WithGraph(graphSchema graph.Graph) graph.Transaction { + // Neo4j does not support multiple graph namespaces within the same database. While Neo4j enterprise supports + // multiple databases this is not the same. Graph namespaces could be hacked using labels but this then requires + // a material change in how labels are applied and therefore was not plumbed. + // + // This has no material effect on the usage of the database: the schema is the same for all graph namespaces. + return s +} + +func (s *neo4jTransaction) Query(query string, parameters map[string]any) graph.Result { + return s.Raw(query, parameters) +} + func (s *neo4jTransaction) updateRelationshipsBy(updates ...graph.RelationshipUpdate) error { var ( numUpdates = len(updates) @@ -69,17 +83,18 @@ func (s *neo4jTransaction) updateRelationshipsBy(updates ...graph.RelationshipUp chunkMap = append(chunkMap, val) if len(chunkMap) == s.batchWriteSize { - if result := s.Run(stmt, map[string]any{ + if result := s.Raw(stmt, map[string]any{ "p": chunkMap, }); result.Error() != nil { return result.Error() } + chunkMap = chunkMap[:0] } } if len(chunkMap) > 0 { - if result := s.Run(stmt, map[string]any{ + if result := s.Raw(stmt, map[string]any{ "p": chunkMap, }); result.Error() != nil { return result.Error() @@ -101,7 +116,7 @@ func (s *neo4jTransaction) updateNodesBy(updates ...graph.NodeUpdate) error { ) for parameterIdx, stmt := range statements { - if result := s.Run(stmt, queryParameterMaps[parameterIdx]); result.Error() != nil { + if result := s.Raw(stmt, queryParameterMaps[parameterIdx]); result.Error() != nil { return fmt.Errorf("update nodes by error on statement (%s): %s", stmt, result.Error()) } } @@ -165,7 +180,7 @@ func (s *neo4jTransaction) logWrites(writes int) error 
{ } func (s *neo4jTransaction) runAndLog(stmt string, params map[string]any, numWrites int) graph.Result { - result := s.Run(stmt, params) + result := s.Raw(stmt, params) if result.Error() == nil { if err := s.logWrites(numWrites); err != nil { @@ -212,7 +227,7 @@ func (s *neo4jTransaction) updateNode(updatedNode *graph.Node) error { return err } else if cypherQuery, err := queryBuilder.Render(); err != nil { return graph.NewError(cypherQuery, err) - } else if result := s.Run(cypherQuery, queryBuilder.Parameters); result.Error() != nil { + } else if result := s.Raw(cypherQuery, queryBuilder.Parameters); result.Error() != nil { return result.Error() } @@ -237,7 +252,7 @@ func (s *neo4jTransaction) createNode(properties *graph.Properties, kinds ...gra return nil, err } else if statement, err := queryBuilder.Render(); err != nil { return nil, err - } else if result := s.Run(statement, queryBuilder.Parameters); result.Error() != nil { + } else if result := s.Raw(statement, queryBuilder.Parameters); result.Error() != nil { return nil, result.Error() } else if !result.Next() { return nil, graph.ErrNoResultsFound @@ -270,7 +285,7 @@ func (s *neo4jTransaction) createRelationshipByIDs(startNodeID, endNodeID graph. return nil, err } else if statement, err := queryBuilder.Render(); err != nil { return nil, err - } else if result := s.Run(statement, queryBuilder.Parameters); result.Error() != nil { + } else if result := s.Raw(statement, queryBuilder.Parameters); result.Error() != nil { return nil, result.Error() } else if !result.Next() { return nil, graph.ErrNoResultsFound @@ -280,10 +295,10 @@ func (s *neo4jTransaction) createRelationshipByIDs(startNodeID, endNodeID graph. 
} } -func (s *neo4jTransaction) Run(stmt string, params map[string]any) graph.Result { +func (s *neo4jTransaction) Raw(stmt string, params map[string]any) graph.Result { const maxParametersToRender = 12 - if IsQueryAnalysisEnabled() { + if drivers.IsQueryAnalysisEnabled() { var ( parametersWritten = 0 prettyParameters strings.Builder @@ -317,13 +332,13 @@ func (s *neo4jTransaction) Run(stmt string, params map[string]any) graph.Result prettyParameters.WriteString(":") if marshalledValue, err := json.Marshal(value); err != nil { - log.Printf("Unable to marshal query parameter %s", key) + log.Errorf("Unable to marshal query parameter %s", key) } else { prettyParameters.Write(marshalledValue) } } - log.Printf("[neo4j] %s - %s", stmt, prettyParameters.String()) + log.Info().Str("dawgs_db_driver", DriverName).Msgf("%s - %s", stmt, prettyParameters.String()) } driverResult, err := s.currentTx().Run(stmt, params) diff --git a/packages/go/dawgs/drivers/neo4j/transaction_internal_test.go b/packages/go/dawgs/drivers/neo4j/transaction_internal_test.go index c44b6b9a6b..72ca59752e 100644 --- a/packages/go/dawgs/drivers/neo4j/transaction_internal_test.go +++ b/packages/go/dawgs/drivers/neo4j/transaction_internal_test.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -20,10 +20,10 @@ import ( "testing" neo4j_core "github.com/neo4j/neo4j-go-driver/v5/neo4j" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/vendormocks/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) var ( @@ -192,7 +192,7 @@ func TestNeo4jTransaction_UpdateRelationshipBy_Batch(t *testing.T) { // Expect that a new transaction is opened and then closed with commit after the final submission and call to commit sessionMock.EXPECT().BeginTransaction(gomock.Any()).Return(transactionMock, nil) resultMock.EXPECT().Err().Return(nil) - transactionMock.EXPECT().Run(`unwind $p as p merge (s:Base {objectid:p.s.objectid}) merge (e:Base {objectid:p.e.objectid}) merge (s)-[r:HasSession {objectid:p.r.objectid}]->(e) set r += p.r, s:Base, s:User, e:Base, e:Computer, s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 'UTC'});`, gomock.Any()).Return(resultMock, nil) + transactionMock.EXPECT().Run(`unwind $p as p merge (s:Base {objectid:p.s.objectid}) merge (e:Base {objectid:p.e.objectid}) merge (s)-[r:HasSession {objectid:p.r.objectid}]->(e) set s += p.s, e += p.e, r += p.r, s:Base, s:User, e:Base, e:Computer, s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 'UTC'});`, gomock.Any()).Return(resultMock, nil) transactionMock.EXPECT().Commit().Return(nil) submitf() @@ -225,15 +225,17 @@ func TestNeo4jTransaction_UpdateRelationshipBy(t *testing.T) { "name": "a_name", } - expectedIdentityProperties = map[string]any{ + expectedPayloa = map[string]any{ "p": []map[string]any{{ "r": map[string]any{ "objectid": "a-b-c", }, "s": map[string]any{ + "name": "a_name", "objectid": "1-2-3", }, "e": map[string]any{ + "name": "a_name", "objectid": "2-3-4", }, }, { @@ -241,9 +243,11 @@ func TestNeo4jTransaction_UpdateRelationshipBy(t 
*testing.T) { "objectid": "a-b-c", }, "s": map[string]any{ + "name": "a_name", "objectid": "1-2-3", }, "e": map[string]any{ + "name": "a_name", "objectid": "2-3-4", }, }}, @@ -251,7 +255,7 @@ func TestNeo4jTransaction_UpdateRelationshipBy(t *testing.T) { ) resultMock.EXPECT().Err().Return(nil) - transactionMock.EXPECT().Run(`unwind $p as p merge (s:Base {objectid:p.s.objectid}) merge (e:Base {objectid:p.e.objectid}) merge (s)-[r:HasSession {objectid:p.r.objectid}]->(e) set r += p.r, s:Base, s:User, e:Base, e:Computer, s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 'UTC'});`, expectedIdentityProperties).Return(resultMock, nil) + transactionMock.EXPECT().Run(`unwind $p as p merge (s:Base {objectid:p.s.objectid}) merge (e:Base {objectid:p.e.objectid}) merge (s)-[r:HasSession {objectid:p.r.objectid}]->(e) set s += p.s, e += p.e, r += p.r, s:Base, s:User, e:Base, e:Computer, s.lastseen = datetime({timezone: 'UTC'}), e.lastseen = datetime({timezone: 'UTC'});`, expectedPayloa).Return(resultMock, nil) transactionMock.EXPECT().Commit().Return(nil) require.Nil(t, tx.UpdateRelationshipBy(graph.RelationshipUpdate{ @@ -293,13 +297,13 @@ func TestNeo4jTransaction_CreateNode(t *testing.T) { } expectedProperties = map[string]any{ - "0": map[string]any{ + "p0": map[string]any{ "prop": "value", }, } ) - transactionMock.EXPECT().Run(`create (n:User $0) return n`, expectedProperties).Return(resultMock, nil) + transactionMock.EXPECT().Run(`create (n:User $p0) return n`, expectedProperties).Return(resultMock, nil) transactionMock.EXPECT().Commit() resultMock.EXPECT().Err().Return(nil) diff --git a/packages/go/dawgs/drivers/pg/batch.go b/packages/go/dawgs/drivers/pg/batch.go new file mode 100644 index 0000000000..315befff8b --- /dev/null +++ b/packages/go/dawgs/drivers/pg/batch.go @@ -0,0 +1,596 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "bytes" + "context" + "fmt" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/bloodhound/cypher/backend/pgsql" + "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + sql "github.com/specterops/bloodhound/dawgs/drivers/pg/query" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/log" + "strconv" + "strings" + "sync/atomic" +) + +type Int2ArrayEncoder struct { + buffer *bytes.Buffer +} + +func (s *Int2ArrayEncoder) Encode(values []int16) string { + s.buffer.Reset() + s.buffer.WriteRune('{') + + for idx, value := range values { + if idx > 0 { + s.buffer.WriteRune(',') + } + + s.buffer.WriteString(strconv.Itoa(int(value))) + } + + s.buffer.WriteRune('}') + return s.buffer.String() +} + +type batch struct { + ctx context.Context + innerTransaction *transaction + schemaManager *SchemaManager + nodeDeletionBuffer []graph.ID + relationshipDeletionBuffer []graph.ID + nodeCreateBuffer []*graph.Node + nodeUpdateByBuffer []graph.NodeUpdate + relationshipCreateBuffer []*graph.Relationship + relationshipUpdateByBuffer []graph.RelationshipUpdate + batchWriteSize int + kindIDEncoder Int2ArrayEncoder +} + +func newBatch(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config) (*batch, error) { + if tx, err := newTransaction(ctx, conn, schemaManager, cfg); err != nil { + return nil, err + } else { + return &batch{ + ctx: ctx, + schemaManager: schemaManager, + innerTransaction: tx, + 
batchWriteSize: cfg.BatchWriteSize, + kindIDEncoder: Int2ArrayEncoder{ + buffer: &bytes.Buffer{}, + }, + }, nil + } +} + +func (s *batch) WithGraph(schema graph.Graph) graph.Batch { + s.innerTransaction.WithGraph(schema) + return s +} + +func (s *batch) CreateNode(node *graph.Node) error { + s.nodeCreateBuffer = append(s.nodeCreateBuffer, node) + return s.tryFlush(s.batchWriteSize) +} + +func (s *batch) Nodes() graph.NodeQuery { + return s.innerTransaction.Nodes() +} + +func (s *batch) Relationships() graph.RelationshipQuery { + return s.innerTransaction.Relationships() +} + +func (s *batch) UpdateNodeBy(update graph.NodeUpdate) error { + s.nodeUpdateByBuffer = append(s.nodeUpdateByBuffer, update) + return s.tryFlush(s.batchWriteSize) +} + +func (s *batch) flushNodeDeleteBuffer() error { + if _, err := s.innerTransaction.tx.Exec(s.ctx, deleteNodeWithIDStatement, s.nodeDeletionBuffer); err != nil { + return err + } + + s.nodeDeletionBuffer = s.nodeDeletionBuffer[:0] + return nil +} + +func (s *batch) flushRelationshipDeleteBuffer() error { + if _, err := s.innerTransaction.tx.Exec(s.ctx, deleteEdgeWithIDStatement, s.relationshipDeletionBuffer); err != nil { + return err + } + + s.relationshipDeletionBuffer = s.relationshipDeletionBuffer[:0] + return nil +} + +func (s *batch) flushNodeCreateBuffer() error { + var ( + withoutIDs = false + withIDs = false + ) + + for _, node := range s.nodeCreateBuffer { + if node.ID == 0 || node.ID == graph.UnregisteredNodeID { + withoutIDs = true + } else { + withIDs = true + } + + if withIDs && withoutIDs { + return fmt.Errorf("batch may not mix preset node IDs with entries that require an auto-generated ID") + } + } + + if withoutIDs { + return s.flushNodeCreateBufferWithoutIDs() + } + + return s.flushNodeCreateBufferWithIDs() +} + +func (s *batch) flushNodeCreateBufferWithIDs() error { + var ( + numCreates = len(s.nodeCreateBuffer) + nodeIDs = make([]uint32, numCreates) + kindIDSlices = make([]string, numCreates) + kindIDEncoder = 
Int2ArrayEncoder{ + buffer: &bytes.Buffer{}, + } + properties = make([]pgtype.JSONB, numCreates) + ) + + for idx, nextNode := range s.nodeCreateBuffer { + nodeIDs[idx] = nextNode.ID.Uint32() + + if mappedKindIDs, missingKinds := s.schemaManager.MapKinds(nextNode.Kinds); len(missingKinds) > 0 { + return fmt.Errorf("unable to map kinds %v", missingKinds) + } else { + kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(nextNode.Properties); err != nil { + return err + } else { + properties[idx] = propertiesJSONB + } + } + + if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else if _, err := s.innerTransaction.tx.Exec(s.ctx, createNodeWithIDBatchStatement, graphTarget.ID, nodeIDs, kindIDSlices, properties); err != nil { + return err + } + + s.nodeCreateBuffer = s.nodeCreateBuffer[:0] + return nil +} + +func (s *batch) flushNodeCreateBufferWithoutIDs() error { + var ( + numCreates = len(s.nodeCreateBuffer) + kindIDSlices = make([]string, numCreates) + kindIDEncoder = Int2ArrayEncoder{ + buffer: &bytes.Buffer{}, + } + properties = make([]pgtype.JSONB, numCreates) + ) + + for idx, nextNode := range s.nodeCreateBuffer { + if mappedKindIDs, missingKinds := s.schemaManager.MapKinds(nextNode.Kinds); len(missingKinds) > 0 { + return fmt.Errorf("unable to map kinds %v", missingKinds) + } else { + kindIDSlices[idx] = kindIDEncoder.Encode(mappedKindIDs) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(nextNode.Properties); err != nil { + return err + } else { + properties[idx] = propertiesJSONB + } + } + + if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else if _, err := s.innerTransaction.tx.Exec(s.ctx, createNodeWithoutIDBatchStatement, graphTarget.ID, kindIDSlices, properties); err != nil { + return err + } + + s.nodeCreateBuffer = s.nodeCreateBuffer[:0] + return nil +} + +func (s *batch) flushNodeUpsertBatch(updates 
*sql.NodeUpdateBatch) error { + parameters := NewNodeUpsertParameters(len(updates.Updates)) + + if err := parameters.AppendAll(updates, s.schemaManager, s.kindIDEncoder); err != nil { + return err + } + + if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else { + query := sql.FormatNodeUpsert(graphTarget, updates.IdentityProperties) + + if rows, err := s.innerTransaction.tx.Query(s.ctx, query, parameters.Format(graphTarget)...); err != nil { + return err + } else { + defer rows.Close() + + idFutureIndex := 0 + + for rows.Next() { + if err := rows.Scan(¶meters.IDFutures[idFutureIndex].Value); err != nil { + return err + } + + idFutureIndex++ + } + } + } + + return nil +} + +func (s *batch) tryFlushNodeUpdateByBuffer() error { + if updates, err := sql.ValidateNodeUpdateByBatch(s.nodeUpdateByBuffer); err != nil { + return err + } else if err := s.flushNodeUpsertBatch(updates); err != nil { + return err + } + + s.nodeUpdateByBuffer = s.nodeUpdateByBuffer[:0] + return nil +} + +type NodeUpsertParameters struct { + IDFutures []*sql.Future[graph.ID] + KindIDSlices []string + Properties []pgtype.JSONB +} + +func NewNodeUpsertParameters(size int) *NodeUpsertParameters { + return &NodeUpsertParameters{ + IDFutures: make([]*sql.Future[graph.ID], 0, size), + KindIDSlices: make([]string, 0, size), + Properties: make([]pgtype.JSONB, 0, size), + } +} + +func (s *NodeUpsertParameters) Format(graphTarget model.Graph) []any { + return []any{ + graphTarget.ID, + s.KindIDSlices, + s.Properties, + } +} + +func (s *NodeUpsertParameters) Append(update *sql.NodeUpdate, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + s.IDFutures = append(s.IDFutures, update.IDFuture) + + if mappedKindIDs, missingKinds := schemaManager.MapKinds(update.Node.Kinds); len(missingKinds) > 0 { + return fmt.Errorf("unable to map kinds %v", missingKinds) + } else { + s.KindIDSlices = append(s.KindIDSlices, kindIDEncoder.Encode(mappedKindIDs)) + } + + 
if propertiesJSONB, err := pgsql.PropertiesToJSONB(update.Node.Properties); err != nil { + return err + } else { + s.Properties = append(s.Properties, propertiesJSONB) + } + + return nil +} + +func (s *NodeUpsertParameters) AppendAll(updates *sql.NodeUpdateBatch, schemaManager *SchemaManager, kindIDEncoder Int2ArrayEncoder) error { + for _, nextUpdate := range updates.Updates { + if err := s.Append(nextUpdate, schemaManager, kindIDEncoder); err != nil { + return err + } + } + + return nil +} + +type RelationshipUpdateByParameters struct { + StartIDs []graph.ID + EndIDs []graph.ID + KindIDs []int16 + Properties []pgtype.JSONB +} + +func NewRelationshipUpdateByParameters(size int) *RelationshipUpdateByParameters { + return &RelationshipUpdateByParameters{ + StartIDs: make([]graph.ID, 0, size), + EndIDs: make([]graph.ID, 0, size), + KindIDs: make([]int16, 0, size), + Properties: make([]pgtype.JSONB, 0, size), + } +} + +func (s *RelationshipUpdateByParameters) Format(graphTarget model.Graph) []any { + return []any{ + graphTarget.ID, + s.StartIDs, + s.EndIDs, + s.KindIDs, + s.Properties, + } +} + +func (s *RelationshipUpdateByParameters) Append(update *sql.RelationshipUpdate, schemaManager *SchemaManager) error { + s.StartIDs = append(s.StartIDs, update.StartID.Value) + s.EndIDs = append(s.EndIDs, update.EndID.Value) + + if mappedKindID, mapped := schemaManager.MapKind(update.Relationship.Kind); !mapped { + return fmt.Errorf("unable to map kind %s", update.Relationship.Kind) + } else { + s.KindIDs = append(s.KindIDs, mappedKindID) + } + + if propertiesJSONB, err := pgsql.PropertiesToJSONB(update.Relationship.Properties); err != nil { + return err + } else { + s.Properties = append(s.Properties, propertiesJSONB) + } + return nil +} + +func (s *RelationshipUpdateByParameters) AppendAll(updates *sql.RelationshipUpdateBatch, schemaManager *SchemaManager) error { + for _, nextUpdate := range updates.Updates { + if err := s.Append(nextUpdate, schemaManager); err != nil { + 
return err + } + } + + return nil +} + +var numRels = &atomic.Int64{} + +func (s *batch) flushRelationshipUpdateByBuffer(updates *sql.RelationshipUpdateBatch) error { + if err := s.flushNodeUpsertBatch(updates.NodeUpdates); err != nil { + return err + } + + parameters := NewRelationshipUpdateByParameters(len(updates.Updates)) + + if err := parameters.AppendAll(updates, s.schemaManager); err != nil { + return err + } + + numRels.Add(int64(len(parameters.Properties))) + + if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else { + query := sql.FormatRelationshipPartitionUpsert(graphTarget) + + if _, err := s.innerTransaction.tx.Exec(s.ctx, query, parameters.Format(graphTarget)...); err != nil { + return err + } + } + + return nil +} + +func (s *batch) tryFlushRelationshipUpdateByBuffer() error { + if updateBatch, err := sql.ValidateRelationshipUpdateByBatch(s.relationshipUpdateByBuffer); err != nil { + return err + } else if err := s.flushRelationshipUpdateByBuffer(updateBatch); err != nil { + return err + } + + s.relationshipUpdateByBuffer = s.relationshipUpdateByBuffer[:0] + return nil +} + +type relationshipCreateBatch struct { + startIDs []uint32 + endIDs []uint32 + edgeKindIDs []int16 + edgePropertyBags []pgtype.JSONB +} + +func newRelationshipCreateBatch(size int) *relationshipCreateBatch { + return &relationshipCreateBatch{ + startIDs: make([]uint32, 0, size), + endIDs: make([]uint32, 0, size), + edgeKindIDs: make([]int16, 0, size), + edgePropertyBags: make([]pgtype.JSONB, 0, size), + } +} + +func (s *relationshipCreateBatch) Add(startID, endID uint32, edgeKindID int16) { + s.startIDs = append(s.startIDs, startID) + s.edgeKindIDs = append(s.edgeKindIDs, edgeKindID) + s.endIDs = append(s.endIDs, endID) +} + +func (s *relationshipCreateBatch) EncodeProperties(edgePropertiesBatch []*graph.Properties) error { + for _, edgeProperties := range edgePropertiesBatch { + if propertiesJSONB, err := 
pgsql.PropertiesToJSONB(edgeProperties); err != nil { + return err + } else { + s.edgePropertyBags = append(s.edgePropertyBags, propertiesJSONB) + } + } + + return nil +} + +type relationshipCreateBatchBuilder struct { + keyToEdgeID map[string]uint32 + relationshipUpdateBatch *relationshipCreateBatch + edgePropertiesIndex map[uint32]int + edgePropertiesBatch []*graph.Properties +} + +func newRelationshipCreateBatchBuilder(size int) *relationshipCreateBatchBuilder { + return &relationshipCreateBatchBuilder{ + keyToEdgeID: map[string]uint32{}, + relationshipUpdateBatch: newRelationshipCreateBatch(size), + edgePropertiesIndex: map[uint32]int{}, + } +} + +func (s *relationshipCreateBatchBuilder) Build() (*relationshipCreateBatch, error) { + return s.relationshipUpdateBatch, s.relationshipUpdateBatch.EncodeProperties(s.edgePropertiesBatch) +} + +func (s *relationshipCreateBatchBuilder) Add(kindMapper KindMapper, edge *graph.Relationship) error { + keyBuilder := strings.Builder{} + + keyBuilder.WriteString(edge.StartID.String()) + keyBuilder.WriteString(edge.EndID.String()) + keyBuilder.WriteString(edge.Kind.String()) + + key := keyBuilder.String() + + if existingPropertiesIdx, hasExisting := s.keyToEdgeID[key]; hasExisting { + s.edgePropertiesBatch[existingPropertiesIdx].Merge(edge.Properties) + } else { + var ( + startID = edge.StartID.Uint32() + edgeID = edge.ID.Uint32() + endID = edge.EndID.Uint32() + edgeProperties = edge.Properties.Clone() + ) + + if edgeKindID, hasKind := kindMapper.MapKind(edge.Kind); !hasKind { + return fmt.Errorf("unable to map kind %s", edge.Kind) + } else { + s.relationshipUpdateBatch.Add(startID, endID, edgeKindID) + } + + s.keyToEdgeID[key] = edgeID + + s.edgePropertiesBatch = append(s.edgePropertiesBatch, edgeProperties) + s.edgePropertiesIndex[edgeID] = len(s.edgePropertiesBatch) - 1 + } + + return nil +} + +func (s *batch) flushRelationshipCreateBuffer() error { + batchBuilder := 
newRelationshipCreateBatchBuilder(len(s.relationshipCreateBuffer)) + + for _, nextRel := range s.relationshipCreateBuffer { + if err := batchBuilder.Add(s.schemaManager, nextRel); err != nil { + return err + } + } + + if createBatch, err := batchBuilder.Build(); err != nil { + return err + } else if graphTarget, err := s.innerTransaction.getTargetGraph(); err != nil { + return err + } else if _, err := s.innerTransaction.tx.Exec(s.ctx, createEdgeBatchStatement, graphTarget.ID, createBatch.startIDs, createBatch.endIDs, createBatch.edgeKindIDs, createBatch.edgePropertyBags); err != nil { + log.Infof("Num merged property bags: %d - Num edge keys: %d - StartID batch size: %d", len(batchBuilder.edgePropertiesIndex), len(batchBuilder.keyToEdgeID), len(batchBuilder.relationshipUpdateBatch.startIDs)) + return err + } + + s.relationshipCreateBuffer = s.relationshipCreateBuffer[:0] + return nil +} + +func (s *batch) tryFlush(batchWriteSize int) error { + if len(s.nodeUpdateByBuffer) > batchWriteSize { + if err := s.tryFlushNodeUpdateByBuffer(); err != nil { + return err + } + } + + if len(s.relationshipUpdateByBuffer) > batchWriteSize { + if err := s.tryFlushRelationshipUpdateByBuffer(); err != nil { + return err + } + } + + if len(s.relationshipCreateBuffer) > batchWriteSize { + if err := s.flushRelationshipCreateBuffer(); err != nil { + return err + } + } + + if len(s.nodeCreateBuffer) > batchWriteSize { + if err := s.flushNodeCreateBuffer(); err != nil { + return err + } + } + + if len(s.nodeDeletionBuffer) > batchWriteSize { + if err := s.flushNodeDeleteBuffer(); err != nil { + return err + } + } + + if len(s.relationshipDeletionBuffer) > batchWriteSize { + if err := s.flushRelationshipDeleteBuffer(); err != nil { + return err + } + } + + return nil +} + +func (s *batch) CreateRelationship(relationship *graph.Relationship) error { + s.relationshipCreateBuffer = append(s.relationshipCreateBuffer, relationship) + return s.tryFlush(s.batchWriteSize) +} + +func (s *batch) 
CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) error { + return s.CreateRelationship(&graph.Relationship{ + StartID: startNodeID, + EndID: endNodeID, + Kind: kind, + Properties: properties, + }) +} + +func (s *batch) UpdateRelationshipBy(update graph.RelationshipUpdate) error { + s.relationshipUpdateByBuffer = append(s.relationshipUpdateByBuffer, update) + return s.tryFlush(s.batchWriteSize) +} + +func (s *batch) Commit() error { + if err := s.tryFlush(0); err != nil { + return err + } + + return s.innerTransaction.Commit() +} + +func (s *batch) DeleteNode(id graph.ID) error { + s.nodeDeletionBuffer = append(s.nodeDeletionBuffer, id) + return s.tryFlush(s.batchWriteSize) +} + +func (s *batch) DeleteRelationship(id graph.ID) error { + s.relationshipDeletionBuffer = append(s.relationshipDeletionBuffer, id) + return s.tryFlush(s.batchWriteSize) +} + +func (s *batch) Close() { + s.innerTransaction.Close() +} diff --git a/packages/go/dawgs/drivers/pg/driver.go b/packages/go/dawgs/drivers/pg/driver.go new file mode 100644 index 0000000000..7dd7d86a3f --- /dev/null +++ b/packages/go/dawgs/drivers/pg/driver.go @@ -0,0 +1,202 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
//
// SPDX-License-Identifier: Apache-2.0

package pg

import (
	"context"
	"fmt"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"time"

	"github.com/specterops/bloodhound/dawgs/graph"
)

// Transaction access modes used when opening read-only vs read-write pgx
// transactions.
var (
	readOnlyTxOptions = pgx.TxOptions{
		AccessMode: pgx.ReadOnly,
	}

	readWriteTxOptions = pgx.TxOptions{
		AccessMode: pgx.ReadWrite,
	}
)

// Config is the PostgreSQL driver-specific transaction configuration carried
// inside graph.TransactionConfig.DriverConfig.
type Config struct {
	Options            pgx.TxOptions
	QueryExecMode      pgx.QueryExecMode
	QueryResultFormats pgx.QueryResultFormats
	BatchWriteSize     int
}

// OptionSetQueryExecMode returns a transaction option that overrides the pgx
// query execution mode when the transaction is backed by this driver; it is a
// no-op for other drivers.
func OptionSetQueryExecMode(queryExecMode pgx.QueryExecMode) graph.TransactionOption {
	return func(config *graph.TransactionConfig) {
		if pgCfg, typeOK := config.DriverConfig.(*Config); typeOK {
			pgCfg.QueryExecMode = queryExecMode
		}
	}
}

// Driver implements the dawgs graph database driver on top of a pgx
// connection pool plus a shared schema manager.
type Driver struct {
	pool                      *pgxpool.Pool
	schemaManager             *SchemaManager
	defaultTransactionTimeout time.Duration
	batchWriteSize            int
}

// KindMapper exposes the schema manager's kind <-> kind-ID mapping.
func (s *Driver) KindMapper() KindMapper {
	return s.schemaManager
}

// SetDefaultGraph asserts the given graph schema and records it as the
// default target graph for subsequent operations.
func (s *Driver) SetDefaultGraph(ctx context.Context, graphSchema graph.Graph) error {
	return s.WriteTransaction(ctx, func(tx graph.Transaction) error {
		return s.schemaManager.AssertDefaultGraph(tx, graphSchema)
	})
}

// SetBatchWriteSize sets the buffer threshold used by batch operations.
func (s *Driver) SetBatchWriteSize(size int) {
	s.batchWriteSize = size
}

// SetWriteFlushSize satisfies the driver interface.
func (s *Driver) SetWriteFlushSize(size int) {
	// This is a no-op function since PostgreSQL does not require transaction rotation like Neo4j does
}

// BatchOperation acquires a pooled connection, runs the delegate against a
// buffered batch writer and commits on success. The connection and batch are
// always released/closed.
func (s *Driver) BatchOperation(ctx context.Context, batchDelegate graph.BatchDelegate) error {
	if cfg, err := renderConfig(s.batchWriteSize, readWriteTxOptions, nil); err != nil {
		return err
	} else if conn, err := s.pool.Acquire(ctx); err != nil {
		return err
	} else {
		defer conn.Release()

		if batch, err := newBatch(ctx, conn, s.schemaManager, cfg); err != nil {
			return err
		} else {
			defer batch.Close()

			if err := batchDelegate(batch); err != nil {
				return err
			}

			return batch.Commit()
		}
	}
}

// Close shuts down the underlying connection pool. The context is unused but
// required by the driver interface.
func (s *Driver) Close(ctx context.Context) error {
	s.pool.Close()
	return nil
}

// renderConfig builds the driver Config from defaults and then applies any
// user-supplied transaction options on top.
func renderConfig(batchWriteSize int, pgxOptions pgx.TxOptions, userOptions []graph.TransactionOption) (*Config, error) {
	graphCfg := graph.TransactionConfig{
		DriverConfig: &Config{
			Options:            pgxOptions,
			QueryExecMode:      pgx.QueryExecModeCacheStatement,
			QueryResultFormats: pgx.QueryResultFormats{pgx.BinaryFormatCode},
			BatchWriteSize:     batchWriteSize,
		},
	}

	for _, option := range userOptions {
		option(&graphCfg)
	}

	if graphCfg.DriverConfig != nil {
		if pgCfg, typeOK := graphCfg.DriverConfig.(*Config); !typeOK {
			// A user option replaced DriverConfig with a foreign type.
			return nil, fmt.Errorf("invalid driver config type %T", graphCfg.DriverConfig)
		} else {
			return pgCfg, nil
		}
	}

	return nil, fmt.Errorf("driver config is nil")
}

// ReadTransaction runs the delegate inside a read-only transaction context
// backed by a pooled connection.
func (s *Driver) ReadTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error {
	if cfg, err := renderConfig(s.batchWriteSize, readOnlyTxOptions, options); err != nil {
		return err
	} else if conn, err := s.pool.Acquire(ctx); err != nil {
		return err
	} else {
		defer conn.Release()

		return txDelegate(&transaction{
			schemaManager:   s.schemaManager,
			queryExecMode:   cfg.QueryExecMode,
			ctx:             ctx,
			conn:            conn,
			targetSchemaSet: false,
		})
	}
}

// WriteTransaction runs the delegate inside a read-write transaction and
// commits on success; Close on the deferred path rolls back otherwise.
func (s *Driver) WriteTransaction(ctx context.Context, txDelegate graph.TransactionDelegate, options ...graph.TransactionOption) error {
	if cfg, err := renderConfig(s.batchWriteSize, readWriteTxOptions, options); err != nil {
		return err
	} else if conn, err := s.pool.Acquire(ctx); err != nil {
		return err
	} else {
		defer conn.Release()

		if tx, err := newTransaction(ctx, conn, s.schemaManager, cfg); err != nil {
			return err
		} else {
			defer tx.Close()

			if err := txDelegate(tx); err != nil {
				return err
			}

			return tx.Commit()
		}
	}
}

// FetchSchema is currently unimplemented; schema negotiation is handled by
// the SchemaManager instead.
func (s *Driver) FetchSchema(ctx context.Context) (graph.Schema, error) {
	// TODO: This is not required for existing functionality as the SchemaManager type handles most of this negotiation
	// however, in the future this function would make it easier to make schema management generic and should be
	// implemented.
	return graph.Schema{}, fmt.Errorf("not implemented")
}

// AssertSchema applies the schema using the simple protocol, resets the pool
// (composite type OIDs may have changed) and then asserts the default graph
// if one is named.
func (s *Driver) AssertSchema(ctx context.Context, schema graph.Schema) error {
	if err := s.WriteTransaction(ctx, func(tx graph.Transaction) error {
		return s.schemaManager.AssertSchema(tx, schema)
	}, OptionSetQueryExecMode(pgx.QueryExecModeSimpleProtocol)); err != nil {
		return err
	} else {
		// Resetting the pool must be done on every schema assertion as composite types may have changed OIDs
		s.pool.Reset()
	}

	if schema.DefaultGraph.Name == "" {
		return nil
	}

	return s.SetDefaultGraph(ctx, schema.DefaultGraph)
}

// Run executes a raw query in a write transaction and discards any rows,
// surfacing only the query's error state.
func (s *Driver) Run(ctx context.Context, query string, parameters map[string]any) error {
	return s.WriteTransaction(ctx, func(tx graph.Transaction) error {
		result := tx.Raw(query, parameters)
		defer result.Close()

		return result.Error()
	})
}

// ---- packages/go/dawgs/drivers/pg/facts.go (new file) ----

// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package pg

import (
	"errors"
	"github.com/jackc/pgx/v5/pgconn"
)

// SQLState is a PostgreSQL SQLSTATE error code (5-character class/condition).
type SQLState string

func (s SQLState) String() string {
	return string(s)
}

// ErrorMatches reports whether err (or anything it wraps) is a pgconn.PgError
// carrying this SQLSTATE code.
func (s SQLState) ErrorMatches(err error) bool {
	var pgConnErr *pgconn.PgError
	return errors.As(err, &pgConnErr) && pgConnErr.Code == s.String()
}

const (
	// StateObjectDoesNotExist is SQLSTATE 42704 (undefined_object).
	StateObjectDoesNotExist SQLState = "42704"
)

// ---- packages/go/dawgs/drivers/pg/manager.go (new file) ----

// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package pg

import (
	"errors"
	"github.com/jackc/pgx/v5"
	"github.com/specterops/bloodhound/dawgs/drivers/pg/model"
	"github.com/specterops/bloodhound/dawgs/drivers/pg/query"
	"github.com/specterops/bloodhound/dawgs/graph"
	"sync"
)

// KindMapper translates between graph.Kind values and their small-integer
// database IDs, optionally defining missing kinds on demand.
type KindMapper interface {
	MapKindID(kindID int16) (graph.Kind, bool)
	MapKindIDs(kindIDs ...int16) (graph.Kinds, []int16)
	MapKind(kind graph.Kind) (int16, bool)
	MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds)
	AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([]int16, error)
}

// SchemaManager caches graph definitions and kind mappings for the PostgreSQL
// driver. All access is guarded by lock; read paths take the read lock and
// upgrade to the write lock only when a definition is missing.
//
// NOTE(review): kindsByID is keyed by kind (value = ID) and kindIDsByKind is
// keyed by ID (value = kind) — the field names read as if they were swapped.
// Confirm intent before renaming.
type SchemaManager struct {
	defaultGraph    model.Graph
	hasDefaultGraph bool
	graphs          map[string]model.Graph
	kindsByID       map[graph.Kind]int16
	kindIDsByKind   map[int16]graph.Kind
	lock            *sync.RWMutex
}

func NewSchemaManager() *SchemaManager {
	return &SchemaManager{
		hasDefaultGraph: false,
		graphs:          map[string]model.Graph{},
		kindsByID:       map[graph.Kind]int16{},
		kindIDsByKind:   map[int16]graph.Kind{},
		lock:            &sync.RWMutex{},
	}
}

// fetch reloads the kind table from the database. Callers must hold the
// write lock. NOTE(review): kindIDsByKind is only appended to here, so stale
// reverse entries could survive a reload — confirm whether kinds are ever
// removed server-side.
func (s *SchemaManager) fetch(tx graph.Transaction) error {
	if kinds, err := query.On(tx).SelectKinds(); err != nil {
		return err
	} else {
		s.kindsByID = kinds

		for kind, kindID := range s.kindsByID {
			s.kindIDsByKind[kindID] = kind
		}
	}

	return nil
}

// defineKinds inserts (or fetches) each kind and records both mapping
// directions. Callers must hold the write lock.
func (s *SchemaManager) defineKinds(tx graph.Transaction, kinds graph.Kinds) error {
	for _, kind := range kinds {
		if kindID, err := query.On(tx).InsertOrGetKind(kind); err != nil {
			return err
		} else {
			s.kindsByID[kind] = kindID
			s.kindIDsByKind[kindID] = kind
		}
	}

	return nil
}

// defineGraphKinds ensures every node and edge kind referenced by the given
// graph schemas is defined. Callers must hold the write lock.
func (s *SchemaManager) defineGraphKinds(tx graph.Transaction, schemas []graph.Graph) error {
	for _, schema := range schemas {
		var (
			_, missingNodeKinds = s.mapKinds(schema.Nodes)
			_, missingEdgeKinds = s.mapKinds(schema.Edges)
		)

		if err := s.defineKinds(tx, missingNodeKinds); err != nil {
			return err
		}

		if err := s.defineKinds(tx, missingEdgeKinds); err != nil {
			return err
		}
	}

	return nil
}

// mapKinds resolves kinds to IDs, returning any kinds that have no mapping.
// Callers must hold at least the read lock.
func (s *SchemaManager) mapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) {
	var (
		missingKinds = make(graph.Kinds, 0, len(kinds))
		ids          = make([]int16, 0, len(kinds))
	)

	for _, kind := range kinds {
		if id, hasID := s.kindsByID[kind]; hasID {
			ids = append(ids, id)
		} else {
			missingKinds = append(missingKinds, kind)
		}
	}

	return ids, missingKinds
}

func (s *SchemaManager) MapKind(kind graph.Kind) (int16, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	id, hasID := s.kindsByID[kind]
	return id, hasID
}

func (s *SchemaManager) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.mapKinds(kinds)
}

// mapKindIDs resolves IDs to kinds, returning any IDs with no mapping.
// Callers must hold at least the read lock.
func (s *SchemaManager) mapKindIDs(kindIDs []int16) (graph.Kinds, []int16) {
	var (
		missingIDs = make([]int16, 0, len(kindIDs))
		kinds      = make(graph.Kinds, 0, len(kindIDs))
	)

	for _, kindID := range kindIDs {
		if kind, hasKind := s.kindIDsByKind[kindID]; hasKind {
			kinds = append(kinds, kind)
		} else {
			missingIDs = append(missingIDs, kindID)
		}
	}

	return kinds, missingIDs
}

func (s *SchemaManager) MapKindID(kindID int16) (graph.Kind, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	kind, hasKind := s.kindIDsByKind[kindID]
	return kind, hasKind
}

func (s *SchemaManager) MapKindIDs(kindIDs ...int16) (graph.Kinds, []int16) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.mapKindIDs(kindIDs)
}

// AssertKinds maps the given kinds to IDs, defining any kinds that are not
// yet known. The common case (all kinds known) only takes the read lock.
func (s *SchemaManager) AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([]int16, error) {
	// Acquire a read-lock first to fast-pass validate if we're missing any kind definitions
	s.lock.RLock()

	if kindIDs, missingKinds := s.mapKinds(kinds); len(missingKinds) == 0 {
		// All kinds are defined. Release the read-lock here before returning
		s.lock.RUnlock()
		return kindIDs, nil
	}

	// Release the read-lock here so that we can acquire a write-lock
	s.lock.RUnlock()

	// Acquire a write-lock and release on-exit
	s.lock.Lock()
	defer s.lock.Unlock()

	// We have to re-acquire the missing kinds since there's a potential for another writer to acquire the write-lock
	// in between release of the read-lock and acquisition of the write-lock for this operation
	_, missingKinds := s.mapKinds(kinds)

	if err := s.defineKinds(tx, missingKinds); err != nil {
		return nil, err
	}

	kindIDs, _ := s.mapKinds(kinds)
	return kindIDs, nil
}

// AssertDefaultGraph asserts the schema and records the resulting graph as
// the default target for operations that do not name a graph.
func (s *SchemaManager) AssertDefaultGraph(tx graph.Transaction, schema graph.Graph) error {
	if graphInstance, err := s.AssertGraph(tx, schema); err != nil {
		return err
	} else {
		s.lock.Lock()
		defer s.lock.Unlock()

		s.defaultGraph = graphInstance
		s.hasDefaultGraph = true
	}

	return nil
}

// DefaultGraph returns the default graph definition and whether one is set.
func (s *SchemaManager) DefaultGraph() (model.Graph, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.defaultGraph, s.hasDefaultGraph
}

// AssertGraph resolves (or creates) the database definition for the given
// graph schema, caching the result. Only a cache miss takes the write lock.
func (s *SchemaManager) AssertGraph(tx graph.Transaction, schema graph.Graph) (model.Graph, error) {
	// Acquire a read-lock first to fast-pass validate if we're missing the graph definitions
	s.lock.RLock()

	if graphInstance, isDefined := s.graphs[schema.Name]; isDefined {
		// The graph is defined. Release the read-lock here before returning
		s.lock.RUnlock()
		return graphInstance, nil
	}

	// Release the read-lock here so that we can acquire a write-lock
	s.lock.RUnlock()

	// Acquire a write-lock and create the graph definition
	s.lock.Lock()
	defer s.lock.Unlock()

	if graphInstance, isDefined := s.graphs[schema.Name]; isDefined {
		// The graph was defined by a different actor between the read unlock and the write lock.
		return graphInstance, nil
	}

	// Validate the schema if the graph already exists in the database
	if definition, err := query.On(tx).SelectGraphByName(schema.Name); err != nil {
		// ErrNoRows signifies that this graph must be created
		if !errors.Is(err, pgx.ErrNoRows) {
			return model.Graph{}, err
		}
	} else if assertedDefinition, err := query.On(tx).AssertGraph(schema, definition); err != nil {
		return model.Graph{}, err
	} else {
		s.graphs[schema.Name] = assertedDefinition
		return assertedDefinition, nil
	}

	// Create the graph
	if definition, err := query.On(tx).CreateGraph(schema); err != nil {
		return model.Graph{}, err
	} else {
		s.graphs[schema.Name] = definition
		return definition, nil
	}
}

// AssertSchema creates the base schema, reloads the kind table and defines
// any kinds referenced by the graphs in the schema. Takes the write lock for
// the whole operation.
func (s *SchemaManager) AssertSchema(tx graph.Transaction, schema graph.Schema) error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if err := query.On(tx).CreateSchema(); err != nil {
		return err
	}

	if err := s.fetch(tx); err != nil {
		return err
	}

	for _, graphSchema := range schema.Graphs {
		if _, missingKinds := s.mapKinds(graphSchema.Nodes); len(missingKinds) > 0 {
			if err := s.defineKinds(tx, missingKinds); err != nil {
				return err
			}
		}

		if _, missingKinds := s.mapKinds(graphSchema.Edges); len(missingKinds) > 0 {
			if err := s.defineKinds(tx, missingKinds); err != nil {
				return err
			}
		}
	}

	return nil
}

// ---- packages/go/dawgs/drivers/pg/mapper.go (new file) ----

// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "fmt" + "github.com/specterops/bloodhound/dawgs/graph" +) + +func mapValue(kindMapper KindMapper) func(rawValue, target any) (bool, error) { + return func(rawValue, target any) (bool, error) { + switch typedTarget := target.(type) { + case *graph.Relationship: + if compositeMap, typeOK := rawValue.(map[string]any); !typeOK { + return false, fmt.Errorf("unexpected edge composite backing type: %T", rawValue) + } else { + edge := edgeComposite{} + + if edge.TryMap(compositeMap) { + if err := edge.ToRelationship(kindMapper, typedTarget); err != nil { + return false, err + } + } else { + return false, nil + } + } + + case *graph.Node: + if compositeMap, typeOK := rawValue.(map[string]any); !typeOK { + return false, fmt.Errorf("unexpected node composite backing type: %T", rawValue) + } else { + node := nodeComposite{} + + if node.TryMap(compositeMap) { + if err := node.ToNode(kindMapper, typedTarget); err != nil { + return false, err + } + } else { + return false, nil + } + } + + case *graph.Path: + if compositeMap, typeOK := rawValue.(map[string]any); !typeOK { + return false, fmt.Errorf("unexpected node composite backing type: %T", rawValue) + } else { + path := pathComposite{} + + if path.TryMap(compositeMap) { + if err := path.ToPath(kindMapper, typedTarget); err != nil { + return false, err + } + } else { + return false, nil + } + } + + default: + return false, nil + } + + return true, nil + } +} + +func NewValueMapper(values []any, kindMapper KindMapper) graph.ValueMapper { + 
return graph.NewValueMapper(values, mapValue(kindMapper)) +} diff --git a/packages/go/dawgs/drivers/pg/model/format.go b/packages/go/dawgs/drivers/pg/model/format.go new file mode 100644 index 0000000000..62feb7a4b7 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/model/format.go @@ -0,0 +1,62 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "github.com/specterops/bloodhound/dawgs/graph" + "strconv" + "strings" +) + +const ( + NodeTable = "node" + EdgeTable = "edge" +) + +func partitionTableName(parent string, graphID int32) string { + return parent + "_" + strconv.FormatInt(int64(graphID), 10) +} + +func NodePartitionTableName(graphID int32) string { + return partitionTableName(NodeTable, graphID) +} + +func EdgePartitionTableName(graphID int32) string { + return partitionTableName(EdgeTable, graphID) +} + +func IndexName(table string, index graph.Index) string { + stringBuilder := strings.Builder{} + + stringBuilder.WriteString(table) + stringBuilder.WriteString("_") + stringBuilder.WriteString(index.Field) + stringBuilder.WriteString("_index") + + return stringBuilder.String() +} + +func ConstraintName(table string, constraint graph.Constraint) string { + stringBuilder := strings.Builder{} + + stringBuilder.WriteString(table) + stringBuilder.WriteString("_") + stringBuilder.WriteString(constraint.Field) + stringBuilder.WriteString("_constraint") + + return 
stringBuilder.String() +} diff --git a/packages/go/dawgs/drivers/pg/model/model.go b/packages/go/dawgs/drivers/pg/model/model.go new file mode 100644 index 0000000000..86affea2ae --- /dev/null +++ b/packages/go/dawgs/drivers/pg/model/model.go @@ -0,0 +1,84 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package model + +import ( + "github.com/specterops/bloodhound/dawgs/graph" +) + +type IndexChangeSet struct { + NodeIndexesToRemove []string + EdgeIndexesToRemove []string + NodeConstraintsToRemove []string + EdgeConstraintsToRemove []string + NodeIndexesToAdd map[string]graph.Index + EdgeIndexesToAdd map[string]graph.Index + NodeConstraintsToAdd map[string]graph.Constraint + EdgeConstraintsToAdd map[string]graph.Constraint +} + +func NewIndexChangeSet() IndexChangeSet { + return IndexChangeSet{ + NodeIndexesToAdd: map[string]graph.Index{}, + NodeConstraintsToAdd: map[string]graph.Constraint{}, + EdgeIndexesToAdd: map[string]graph.Index{}, + EdgeConstraintsToAdd: map[string]graph.Constraint{}, + } +} + +type GraphPartition struct { + Name string + Indexes map[string]graph.Index + Constraints map[string]graph.Constraint +} + +func NewGraphPartition(name string) GraphPartition { + return GraphPartition{ + Name: name, + Indexes: map[string]graph.Index{}, + Constraints: map[string]graph.Constraint{}, + } +} + +func NewGraphPartitionFromSchema(name string, indexes []graph.Index, constraints 
[]graph.Constraint) GraphPartition { + graphPartition := GraphPartition{ + Name: name, + Indexes: make(map[string]graph.Index, len(indexes)), + Constraints: make(map[string]graph.Constraint, len(constraints)), + } + + for _, index := range indexes { + graphPartition.Indexes[IndexName(name, index)] = index + } + + for _, constraint := range constraints { + graphPartition.Constraints[ConstraintName(name, constraint)] = constraint + } + + return graphPartition +} + +type GraphPartitions struct { + Node GraphPartition + Edge GraphPartition +} + +type Graph struct { + ID int32 + Name string + Partitions GraphPartitions +} diff --git a/packages/go/dawgs/drivers/pg/node.go b/packages/go/dawgs/drivers/pg/node.go new file mode 100644 index 0000000000..c00d8f2b14 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/node.go @@ -0,0 +1,225 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
//
// SPDX-License-Identifier: Apache-2.0

package pg

import (
	"bytes"
	"context"
	"github.com/specterops/bloodhound/cypher/backend/pgsql"
	"github.com/specterops/bloodhound/cypher/backend/pgsql/pgtransition"
	"github.com/specterops/bloodhound/dawgs/graph"
	"github.com/specterops/bloodhound/dawgs/query"
)

// liveQuery accumulates dawgs query criteria and translates them to SQL for
// execution against the bound transaction.
type liveQuery struct {
	ctx          context.Context
	tx           graph.Transaction
	kindMapper   KindMapper
	emitter      *pgsql.Emitter
	parameters   map[string]any
	queryBuilder *query.Builder
}

// newLiveQuery binds an empty query builder to a transaction and kind mapper.
func newLiveQuery(ctx context.Context, tx graph.Transaction, kindMapper KindMapper) liveQuery {
	return liveQuery{
		ctx:          ctx,
		tx:           tx,
		kindMapper:   kindMapper,
		emitter:      pgsql.NewEmitter(false, kindMapper),
		parameters:   map[string]any{},
		queryBuilder: query.NewBuilder(nil),
	}
}

// runAllShortestPathsQuery translates the built query into all_shortest_paths
// arguments and executes the dedicated stored-procedure query.
func (s *liveQuery) runAllShortestPathsQuery() graph.Result {
	if aspArguments, err := pgtransition.TranslateAllShortestPaths(s.queryBuilder.RegularQuery(), s.kindMapper); err != nil {
		return graph.NewErrorResult(err)
	} else {
		return s.tx.Raw(`select (t.nodes, t.edges)::pathComposite from all_shortest_paths(@p1, @p2, @p3, @p4) as t`, map[string]any{
			"p1": aspArguments.RootCriteria,
			"p2": aspArguments.TraversalCriteria,
			"p3": aspArguments.TerminalCriteria,
			"p4": aspArguments.MaxDepth,
		})
	}
}

// runRegularQuery builds, translates and emits the accumulated criteria as a
// SQL statement, then executes it.
func (s *liveQuery) runRegularQuery() graph.Result {
	buffer := &bytes.Buffer{}

	if regularQuery, err := s.queryBuilder.Build(); err != nil {
		return graph.NewErrorResult(err)
	} else if arguments, err := pgsql.Translate(regularQuery, s.kindMapper); err != nil {
		return graph.NewErrorResult(err)
	} else if err := s.emitter.Write(regularQuery, buffer); err != nil {
		return graph.NewErrorResult(err)
	} else {
		return s.tx.Raw(buffer.String(), arguments)
	}
}

// Query applies any final criteria, runs the query and hands the open result
// to the delegate. The result is closed after the delegate returns.
func (s *liveQuery) Query(delegate func(results graph.Result) error, finalCriteria ...graph.Criteria) error {
	for _, criteria := range finalCriteria {
		s.queryBuilder.Apply(criteria)
	}

	if result := s.runRegularQuery(); result.Error() != nil {
		return result.Error()
	} else {
		defer result.Close()
		return delegate(result)
	}
}

// exec runs the query for its side effects only, surfacing the result error.
func (s *liveQuery) exec(finalCriteria ...graph.Criteria) error {
	return s.Query(func(results graph.Result) error {
		return results.Error()
	}, finalCriteria...)
}

// nodeQuery implements graph.NodeQuery on top of liveQuery.
type nodeQuery struct {
	liveQuery
}

// Filter adds a where-clause criterion to the query.
func (s *nodeQuery) Filter(criteria graph.Criteria) graph.NodeQuery {
	s.queryBuilder.Apply(query.Where(criteria))
	return s
}

// Filterf is Filter with lazily-built criteria.
func (s *nodeQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.NodeQuery {
	return s.Filter(criteriaDelegate())
}

// Delete removes all nodes matched by the query.
func (s *nodeQuery) Delete() error {
	return s.exec(query.Delete(
		query.Node(),
	))
}

// Update applies the tracked property modifications and deletions of the
// given properties instance to all matched nodes.
func (s *nodeQuery) Update(properties *graph.Properties) error {
	return s.exec(query.Updatef(func() graph.Criteria {
		var updateStatements []graph.Criteria

		if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 {
			updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties))
		}

		if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 {
			updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...))
		}

		return updateStatements
	}))
}

// OrderBy adds ordering criteria to the query.
func (s *nodeQuery) OrderBy(criteria ...graph.Criteria) graph.NodeQuery {
	s.queryBuilder.Apply(query.OrderBy(criteria...))
	return s
}

// Offset skips the first offset matched rows.
func (s *nodeQuery) Offset(offset int) graph.NodeQuery {
	s.queryBuilder.Apply(query.Offset(offset))
	return s
}

// Limit caps the number of matched rows.
func (s *nodeQuery) Limit(limit int) graph.NodeQuery {
	s.queryBuilder.Apply(query.Limit(limit))
	return s
}

// Count returns the number of nodes matched by the query.
func (s *nodeQuery) Count() (int64, error) {
	var count int64

	return count, s.Query(func(results graph.Result) error {
		if !results.Next() {
			return graph.ErrNoResultsFound
		}

		return results.Scan(&count)
	}, query.Returning(
		query.Count(query.Node()),
	))
}

// First returns the first matched node or graph.ErrNoResultsFound.
func (s *nodeQuery) First() (*graph.Node, error) {
	var node graph.Node

	return &node, s.Query(
		func(results graph.Result) error {
			if !results.Next() {
				return graph.ErrNoResultsFound
			}

			return results.Scan(&node)
		},
		query.Returning(
			query.Node(),
		),
		query.Limit(1),
	)
}

// Fetch streams matched nodes to the delegate via a cursor.
func (s *nodeQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Node]) error) error {
	return s.Query(func(result graph.Result) error {
		cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (*graph.Node, error) {
			var node graph.Node
			return &node, scanner.Scan(&node)
		})

		defer cursor.Close()
		return delegate(cursor)
	}, query.Returning(
		query.Node(),
	))
}

// FetchIDs streams only the matched node IDs to the delegate.
func (s *nodeQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error {
	return s.Query(func(result graph.Result) error {
		cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.ID, error) {
			var nodeID graph.ID
			return nodeID, scanner.Scan(&nodeID)
		})

		defer cursor.Close()
		return delegate(cursor)
	}, query.Returning(
		query.NodeID(),
	))
}

// FetchKinds streams matched node IDs paired with their kinds.
func (s *nodeQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.KindsResult]) error) error {
	return s.Query(func(result graph.Result) error {
		cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.KindsResult, error) {
			var (
				nodeID    graph.ID
				nodeKinds graph.Kinds
				err       = scanner.Scan(&nodeID, &nodeKinds)
			)

			return graph.KindsResult{
				ID:    nodeID,
				Kinds: nodeKinds,
			}, err
		})

		defer cursor.Close()
		return delegate(cursor)
	}, query.Returning(
		query.NodeID(),
		query.KindsOf(query.Node()),
	))
}

// ---- packages/go/dawgs/drivers/pg/node_test.go (new file) ----

// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package pg

import (
	"context"
	"github.com/specterops/bloodhound/dawgs/graph"
	graph_mocks "github.com/specterops/bloodhound/dawgs/graph/mocks"
	"github.com/specterops/bloodhound/dawgs/query"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
	"testing"
)

// testKindMapper is a minimal KindMapper stub for tests: only MapKinds is
// implemented; the remaining methods panic if exercised.
type testKindMapper struct {
	known map[string]int16
}

func (s testKindMapper) MapKindID(kindID int16) (graph.Kind, bool) {
	panic("implement me")
}

func (s testKindMapper) MapKindIDs(kindIDs ...int16) (graph.Kinds, []int16) {
	panic("implement me")
}

func (s testKindMapper) MapKind(kind graph.Kind) (int16, bool) {
	panic("implement me")
}

func (s testKindMapper) AssertKinds(tx graph.Transaction, kinds graph.Kinds) ([]int16, error) {
	panic("implement me")
}

// MapKinds resolves kinds against the known map, collecting unknown kinds.
func (s testKindMapper) MapKinds(kinds graph.Kinds) ([]int16, graph.Kinds) {
	var (
		kindIDs      = make([]int16, 0, len(kinds))
		missingKinds = make([]graph.Kind, 0, len(kinds))
	)

	for _, kind := range kinds {
		if kindID, hasKind := s.known[kind.String()]; hasKind {
			kindIDs = append(kindIDs, kindID)
		} else {
			missingKinds = append(missingKinds, kind)
		}
	}

	return kindIDs, missingKinds
}

// TestNodeQuery verifies that a filtered First() emits the expected SQL
// (property filter plus limit 1) against a mocked transaction.
func TestNodeQuery(t *testing.T) {
	var (
		mockCtrl   = gomock.NewController(t)
		mockTx     = graph_mocks.NewMockTransaction(mockCtrl)
		mockResult = graph_mocks.NewMockResult(mockCtrl)

		kindMapper = testKindMapper{
			known: map[string]int16{
				"NodeKindA": 1,
				"NodeKindB": 2,
				"EdgeKindA": 3,
				"EdgeKindB": 4,
			},
		}

		nodeQueryInst = &nodeQuery{
			liveQuery: newLiveQuery(context.Background(), mockTx, kindMapper),
		}
	)

	// The generated SQL is asserted verbatim; parameters are not inspected.
	mockTx.EXPECT().Raw("select (n.id, n.kind_ids, n.properties)::nodeComposite as n from node as n where (n.properties->>'prop')::text = @p0 limit 1", gomock.Any()).Return(mockResult)

	mockResult.EXPECT().Error().Return(nil)
	mockResult.EXPECT().Next().Return(true)
	mockResult.EXPECT().Close().Return()
	mockResult.EXPECT().Scan(gomock.Any()).Return(nil)

	nodeQueryInst.Filter(
		query.Equals(query.NodeProperty("prop"), "1234"),
	)

	_, err := nodeQueryInst.First()
	require.Nil(t, err)
}

// ---- packages/go/dawgs/drivers/pg/pg.go (new file) ----

// Copyright 2023 Specter Ops, Inc.
//
// Licensed under the Apache License, Version 2.0
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "context" + "fmt" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/bloodhound/cypher/model/pg" + "github.com/specterops/bloodhound/dawgs" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/log" + "time" +) + +const ( + DriverName = "pg" + + poolInitConnectionTimeout = time.Second * 10 + defaultTransactionTimeout = time.Minute * 15 + defaultBatchWriteSize = 20_000 +) + +func afterPooledConnectionEstablished(ctx context.Context, conn *pgx.Conn) error { + log.Debugf("Established a new database connection.") + + for _, dataType := range pg.CompositeTypes { + if definition, err := conn.LoadType(ctx, dataType.String()); err != nil { + if !StateObjectDoesNotExist.ErrorMatches(err) { + return fmt.Errorf("failed to match composite type %s to database: %w", dataType, err) + } + } else { + conn.TypeMap().RegisterType(definition) + } + } + + return nil +} + +func afterPooledConnectionRelease(conn *pgx.Conn) bool { + for _, dataType := range pg.CompositeTypes { + if _, hasType := conn.TypeMap().TypeForName(dataType.String()); !hasType { + // This connection should be destroyed since it does not contain information regarding the schema's + // composite types + log.Warnf("Unable to find expected data type: %s. This database connection will not be pooled.", dataType) + return false + } + } + + return true +} + +func newDatabase(connectionString string) (*Driver, error) { + poolCtx, done := context.WithTimeout(context.Background(), poolInitConnectionTimeout) + defer done() + + if poolCfg, err := pgxpool.ParseConfig(connectionString); err != nil { + return nil, err + } else { + // TODO: Min and Max connections for the pool should be configurable + poolCfg.MinConns = 5 + poolCfg.MaxConns = 50 + + // Bind functions to the AfterConnect and AfterRelease hooks to ensure that composite type registration occurs. 
+ // Without composite type registration, the pgx connection type will not be able to marshal PG OIDs to their + // respective Golang structs. + poolCfg.AfterConnect = afterPooledConnectionEstablished + poolCfg.AfterRelease = afterPooledConnectionRelease + + if pool, err := pgxpool.NewWithConfig(poolCtx, poolCfg); err != nil { + return nil, err + } else { + return &Driver{ + pool: pool, + schemaManager: NewSchemaManager(), + defaultTransactionTimeout: defaultTransactionTimeout, + batchWriteSize: defaultBatchWriteSize, + }, nil + } + } +} + +func init() { + dawgs.Register(DriverName, func(ctx context.Context, cfg dawgs.Config) (graph.Database, error) { + if connectionString, typeOK := cfg.DriverCfg.(string); !typeOK { + return nil, fmt.Errorf("expected string for configuration type but got %T", cfg.DriverCfg) + } else if graphDB, err := newDatabase(connectionString); err != nil { + return nil, err + } else if err := graphDB.AssertSchema(ctx, graph.Schema{}); err != nil { + return nil, err + } else { + return graphDB, nil + } + }) +} diff --git a/packages/go/dawgs/drivers/pg/query/definitions.go b/packages/go/dawgs/drivers/pg/query/definitions.go new file mode 100644 index 0000000000..4968d52a8a --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/definitions.go @@ -0,0 +1,37 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package query + +import "regexp" + +var ( + pgPropertyIndexRegex = regexp.MustCompile(`(?i)^create\s+(unique)?(?:\s+)?index\s+([^ ]+)\s+on\s+\S+\s+using\s+([^ ]+)\s+\(+properties\s+->>\s+'([^:]+)::.+$`) + pgColumnIndexRegex = regexp.MustCompile(`(?i)^create\s+(unique)?(?:\s+)?index\s+([^ ]+)\s+on\s+\S+\s+using\s+([^ ]+)\s+\(([^)]+)\)$`) +) + +const ( + pgIndexRegexGroupUnique = 1 + pgIndexRegexGroupName = 2 + pgIndexRegexGroupIndexType = 3 + pgIndexRegexGroupFields = 4 + pgIndexRegexNumExpectedGroups = 5 + + pgIndexTypeBTree = "btree" + pgIndexTypeGIN = "gin" + pgIndexUniqueStr = "unique" + pgPropertiesColumn = "properties" +) diff --git a/packages/go/dawgs/drivers/pg/query/format.go b/packages/go/dawgs/drivers/pg/query/format.go new file mode 100644 index 0000000000..ac35268d5b --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/format.go @@ -0,0 +1,336 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package query + +import ( + "fmt" + "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + "github.com/specterops/bloodhound/dawgs/graph" + "strconv" + "strings" +) + +func postgresIndexType(indexType graph.IndexType) string { + switch indexType { + case graph.BTreeIndex: + return pgIndexTypeBTree + case graph.TextSearchIndex: + return pgIndexTypeGIN + default: + return "NOT SUPPORTED" + } +} + +func parsePostgresIndexType(pgType string) graph.IndexType { + switch strings.ToLower(pgType) { + case pgIndexTypeBTree: + return graph.BTreeIndex + case pgIndexTypeGIN: + return graph.TextSearchIndex + default: + return graph.UnsupportedIndex + } +} + +func join(values ...string) string { + return strings.Join(values, "") +} + +func formatDropPropertyIndex(indexName string) string { + return join("drop index if exists ", indexName, ";") +} + +func formatDropPropertyConstraint(constraintName string) string { + return join("drop index if exists ", constraintName, ";") +} + +func formatCreatePropertyConstraint(constraintName, tableName, fieldName string, indexType graph.IndexType) string { + pgIndexType := postgresIndexType(indexType) + + return join("create unique index ", constraintName, " on ", tableName, " using ", + pgIndexType, " ((", tableName, ".", pgPropertiesColumn, " ->> '", fieldName, "'));") +} + +func formatCreatePropertyIndex(indexName, tableName, fieldName string, indexType graph.IndexType) string { + var ( + pgIndexType = postgresIndexType(indexType) + queryPartial = join("create index ", indexName, " on ", tableName, " using ", + pgIndexType, " ((", tableName, ".", pgPropertiesColumn, " ->> '", fieldName) + ) + + if indexType == graph.TextSearchIndex { + // GIN text search requires the column to be typed and to contain the tri-gram operation extension + return join(queryPartial, "'::text) gin_trgm_ops);") + } else { + return join(queryPartial, "'));") + } +} + +func formatCreatePartitionTable(name, parent string, graphID int32) string { + builder := strings.Builder{} + + builder.WriteString("create table ") + builder.WriteString(name) + builder.WriteString(" partition of ") + builder.WriteString(parent) + builder.WriteString(" for values in (") + builder.WriteString(strconv.FormatInt(int64(graphID), 10)) + builder.WriteString(")") + + return builder.String() +} + +func formatConflictMatcher(propertyNames []string, defaultOnConflict string) string { + builder := strings.Builder{} + builder.WriteString("on conflict (") + + if len(propertyNames) > 0 { + for idx, propertyName := range propertyNames { + if idx > 0 { + builder.WriteString(", ") + } + + builder.WriteString("(properties->>'") + builder.WriteString(propertyName) + builder.WriteString("')") + } + } else { + builder.WriteString(defaultOnConflict) + } + + builder.WriteString(") ") + return builder.String() +} + +type idCounter struct { + current int +} + +func newIDCounter() *idCounter { + return &idCounter{ + current: 1, + } +} + +func (s *idCounter) Next() string { + next := s.current + s.current += 1 + + return strconv.Itoa(next) +} + +func FormatNodeUpsert(graphTarget model.Graph, identityProperties []string) string { + return join( + "insert into ", graphTarget.Partitions.Node.Name, " as n ", + "(graph_id, kind_ids, properties) ", + "select $1::int4, unnest($2::text[])::int2[], unnest($3::jsonb[]) ", + formatConflictMatcher(identityProperties, "id, graph_id"), + "do update set properties = n.properties || excluded.properties, kind_ids = uniq(sort(n.kind_ids || excluded.kind_ids)) ", + "returning id;", + ) +} + +func FormatRelationshipPartitionUpsert(graphTarget model.Graph) string { + return join( + "merge into ", graphTarget.Partitions.Edge.Name, " as e ", + "using (select $1::int4 as gid, unnest($2::int4[]) as sid, unnest($3::int4[]) as eid, unnest($4::int2[]) as kid, unnest($5::jsonb[]) as p) as ei ", + "on e.start_id = ei.sid and e.end_id = ei.eid and e.kind_id = ei.kid ", + "when matched then update set properties = e.properties || ei.p ", + "when not matched then insert (graph_id, start_id, end_id, kind_id, properties) values (ei.gid, ei.sid, ei.eid, ei.kid, ei.p);", + ) +} + +type NodeUpdate struct { + IDFuture *Future[graph.ID] + Node *graph.Node +} + +// NodeUpdateBatch +// +// TODO: See note below +// +// Some assumptions were made here regarding identity kind matching since this data model does not directly require the +// kind of a node to enforce a constraint +type NodeUpdateBatch struct { + IdentityProperties []string + Updates map[string]*NodeUpdate +} + +func NewNodeUpdateBatch() *NodeUpdateBatch { + return &NodeUpdateBatch{ + Updates: map[string]*NodeUpdate{}, + } +} + +func (s *NodeUpdateBatch) Add(update graph.NodeUpdate) (*Future[graph.ID], error) { + if len(s.IdentityProperties) > 0 && len(update.IdentityProperties) != len(s.IdentityProperties) { + return nil, fmt.Errorf("node update mixes identity properties with pre-existing updates") + } + + for _, expectedIdentityProperty := range s.IdentityProperties { + found := false + + for _, updateIdentityProperty := range update.IdentityProperties { + if expectedIdentityProperty == updateIdentityProperty { + found = true + break + } + } + + if !found { + return nil, fmt.Errorf("node update mixes identity properties with pre-existing updates") + } + } + + if key, err := update.Key(); err != nil { + return nil, err + } else { + update.Node.AddKinds(update.IdentityKind) + + if len(s.IdentityProperties) == 0 { + s.IdentityProperties = make([]string, len(update.IdentityProperties)) + copy(s.IdentityProperties, update.IdentityProperties) + } + + if existingUpdate, hasExisting := s.Updates[key]; hasExisting { + existingUpdate.Node.Merge(update.Node) + return existingUpdate.IDFuture, nil + } else { + newIDFuture := NewFuture(graph.ID(0)) + + s.Updates[key] = &NodeUpdate{ + IDFuture: newIDFuture, + Node: update.Node, + } + + return newIDFuture, nil + } + + } +} + +func ValidateNodeUpdateByBatch(updates []graph.NodeUpdate) (*NodeUpdateBatch, error) { + updateBatch := NewNodeUpdateBatch() + + for _, update := range updates { + if _, err := updateBatch.Add(update); err != nil { + return nil, err + } + } + + return updateBatch, nil +} + +type Future[T any] struct { + Value T +} + +func NewFuture[T any](value T) *Future[T] { + return &Future[T]{ + Value: value, + } +} + +type RelationshipUpdate struct { + StartID *Future[graph.ID] + EndID *Future[graph.ID] + Relationship *graph.Relationship +} + +type RelationshipUpdateBatch struct { + NodeUpdates *NodeUpdateBatch + IdentityProperties []string + Updates map[string]*RelationshipUpdate +} + +func NewRelationshipUpdateBatch() *RelationshipUpdateBatch { + return &RelationshipUpdateBatch{ + NodeUpdates: NewNodeUpdateBatch(), + Updates: map[string]*RelationshipUpdate{}, + } +} + +func (s *RelationshipUpdateBatch) Add(update graph.RelationshipUpdate) error { + if len(s.IdentityProperties) > 0 && len(update.IdentityProperties) != len(s.IdentityProperties) { + return fmt.Errorf("relationship update mixes identity properties with pre-existing updates") + } + + for _, expectedIdentityProperty := range s.IdentityProperties { + found := false + + for _, updateIdentityProperty := range update.IdentityProperties { + if expectedIdentityProperty == updateIdentityProperty { + found = true + break + } + } + + if !found { + return fmt.Errorf("relationship update mixes identity properties with pre-existing updates") + } + } + + if startNodeID, err := s.NodeUpdates.Add(graph.NodeUpdate{ + Node: update.Start, + IdentityKind: update.StartIdentityKind, + IdentityProperties: update.StartIdentityProperties, + }); err != nil { + return err + } else if endNodeID, err := s.NodeUpdates.Add(graph.NodeUpdate{ + Node: update.End, + IdentityKind: update.EndIdentityKind, + IdentityProperties: update.EndIdentityProperties, + }); err != nil { + return err + } else if key, err := update.Key(); err != nil { + return err + } else { + if len(s.IdentityProperties) == 0 { + s.IdentityProperties = make([]string, len(update.IdentityProperties)) + copy(s.IdentityProperties, update.IdentityProperties) + } + + if existingUpdate, hasExisting := s.Updates[key]; hasExisting { + existingUpdate.Relationship.Merge(update.Relationship) + } else { + s.Updates[key] = &RelationshipUpdate{ + StartID: startNodeID, + EndID: endNodeID, + Relationship: update.Relationship, + } + } + } + + return nil +} + +func ValidateRelationshipUpdateByBatch(updates []graph.RelationshipUpdate) (*RelationshipUpdateBatch, error) { + updateBatch := NewRelationshipUpdateBatch() + + for _, update := range updates { + if err := updateBatch.Add(update); err != nil { + return nil, err + } + } + + return updateBatch, nil +} diff --git a/packages/go/dawgs/drivers/pg/query/query.go b/packages/go/dawgs/drivers/pg/query/query.go new file mode 100644 index 0000000000..4ddff6e472 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/query.go @@ -0,0 +1,487 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package query + +import ( + _ "embed" + "fmt" + "github.com/jackc/pgx/v5" + "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + "github.com/specterops/bloodhound/dawgs/graph" +) + +type Query struct { + tx graph.Transaction +} + +func On(tx graph.Transaction) Query { + return Query{ + tx: tx, + } +} + +func (s Query) exec(statement string, args map[string]any) error { + result := s.tx.Raw(statement, args) + defer result.Close() + + return result.Error() +} + +func (s Query) describeGraphPartition(name string) (model.GraphPartition, error) { + graphPartition := model.NewGraphPartition(name) + + if tableIndexDefinitions, err := s.SelectTableIndexDefinitions(name); err != nil { + return graphPartition, err + } else { + for _, tableIndexDefinition := range tableIndexDefinitions { + if captureGroups := pgPropertyIndexRegex.FindStringSubmatch(tableIndexDefinition); captureGroups == nil { + // If this index does not match our expected column index format then report it as a potential error + if !pgColumnIndexRegex.MatchString(tableIndexDefinition) { + return graphPartition, fmt.Errorf("regex mis-match on schema definition: %s", tableIndexDefinition) + } + } else { + indexName := captureGroups[pgIndexRegexGroupName] + + if captureGroups[pgIndexRegexGroupUnique] == pgIndexUniqueStr { + graphPartition.Constraints[indexName] = graph.Constraint{ + Name: indexName, + Field: captureGroups[pgIndexRegexGroupFields], + Type: parsePostgresIndexType(captureGroups[pgIndexRegexGroupIndexType]), + } + } else { + graphPartition.Indexes[indexName] = graph.Index{ + Name: indexName, + Field: captureGroups[pgIndexRegexGroupFields], + Type: parsePostgresIndexType(captureGroups[pgIndexRegexGroupIndexType]), + } + } + } + } + } + + return graphPartition, nil +} + +func (s Query) SelectKinds() (map[graph.Kind]int16, error) { + var ( + kindID int16 + kindName string + + kinds = map[graph.Kind]int16{} + result = s.tx.Raw(sqlSelectKinds, nil) + 
) + + defer result.Close() + + for result.Next() { + if err := result.Scan(&kindID, &kindName); err != nil { + return nil, err + } + + kinds[graph.StringKind(kindName)] = kindID + } + + return kinds, result.Error() +} + +func (s Query) selectGraphPartitions(graphID int32) (model.GraphPartitions, error) { + var ( + nodePartitionName = model.NodePartitionTableName(graphID) + edgePartitionName = model.EdgePartitionTableName(graphID) + ) + + if nodePartition, err := s.describeGraphPartition(nodePartitionName); err != nil { + return model.GraphPartitions{}, err + } else if edgePartition, err := s.describeGraphPartition(edgePartitionName); err != nil { + return model.GraphPartitions{}, err + } else { + return model.GraphPartitions{ + Node: nodePartition, + Edge: edgePartition, + }, nil + } +} + +func (s Query) selectGraphPartialByName(name string) (model.Graph, error) { + var ( + graphID int32 + result = s.tx.Raw(sqlSelectGraphByName, map[string]any{ + "name": name, + }) + ) + + defer result.Close() + + if !result.Next() { + return model.Graph{}, pgx.ErrNoRows + } + + if err := result.Scan(&graphID); err != nil { + return model.Graph{}, err + } + + return model.Graph{ + ID: graphID, + Name: name, + }, result.Error() +} + +func (s Query) SelectGraphByName(name string) (model.Graph, error) { + if definition, err := s.selectGraphPartialByName(name); err != nil { + return model.Graph{}, err + } else if graphPartitions, err := s.selectGraphPartitions(definition.ID); err != nil { + return model.Graph{}, err + } else { + definition.Partitions = graphPartitions + return definition, nil + } +} + +func (s Query) selectGraphPartials() ([]model.Graph, error) { + var ( + graphID int32 + graphName string + graphs []model.Graph + + result = s.tx.Raw(sqlSelectGraphs, nil) + ) + + defer result.Close() + + for result.Next() { + if err := result.Scan(&graphID, &graphName); err != nil { + return nil, err + } else { + graphs = append(graphs, model.Graph{ + ID: graphID, + Name: graphName, + 
}) + } + } + + return graphs, result.Error() +} + +func (s Query) SelectGraphs() (map[string]model.Graph, error) { + if definitions, err := s.selectGraphPartials(); err != nil { + return nil, err + } else { + indexed := map[string]model.Graph{} + + for _, definition := range definitions { + if graphPartitions, err := s.selectGraphPartitions(definition.ID); err != nil { + return nil, err + } else { + definition.Partitions = graphPartitions + indexed[definition.Name] = definition + } + } + + return indexed, nil + } +} + +func (s Query) CreatePropertyIndex(indexName, tableName, fieldName string, indexType graph.IndexType) error { + return s.exec(formatCreatePropertyIndex(indexName, tableName, fieldName, indexType), nil) +} + +func (s Query) CreatePropertyConstraint(indexName, tableName, fieldName string, indexType graph.IndexType) error { + if indexType != graph.BTreeIndex { + return fmt.Errorf("only b-tree indexing is supported for property constraints") + } + + return s.exec(formatCreatePropertyConstraint(indexName, tableName, fieldName, indexType), nil) +} + +func (s Query) DropIndex(indexName string) error { + return s.exec(formatDropPropertyIndex(indexName), nil) +} + +func (s Query) DropConstraint(constraintName string) error { + return s.exec(formatDropPropertyConstraint(constraintName), nil) +} + +func (s Query) CreateSchema() error { + if err := s.exec(sqlSchemaUp, nil); err != nil { + return err + } + + return nil +} + +func (s Query) DropSchema() error { + if err := s.exec(sqlSchemaDown, nil); err != nil { + return err + } + + return nil +} + +func (s Query) insertGraph(name string) (model.Graph, error) { + var ( + graphID int32 + result = s.tx.Raw(sqlInsertGraph, map[string]any{ + "name": name, + }) + ) + + defer result.Close() + + if !result.Next() { + return model.Graph{}, result.Error() + } + + if err := result.Scan(&graphID); err != nil { + return model.Graph{}, fmt.Errorf("failed mapping ID from graph entry creation: %w", err) + } + + return 
model.Graph{ + ID: graphID, + Name: name, + }, nil +} + +func (s Query) CreatePartitionTable(name, parent string, graphID int32) (model.GraphPartition, error) { + if err := s.exec(formatCreatePartitionTable(name, parent, graphID), nil); err != nil { + return model.GraphPartition{}, err + } + + return model.GraphPartition{ + Name: name, + }, nil +} + +func (s Query) SelectTableIndexDefinitions(tableName string) ([]string, error) { + var ( + definition string + definitions []string + + result = s.tx.Raw(sqlSelectTableIndexes, map[string]any{ + "tablename": tableName, + }) + ) + + defer result.Close() + + for result.Next() { + if err := result.Scan(&definition); err != nil { + return nil, err + } + + definitions = append(definitions, definition) + } + + return definitions, result.Error() +} + +func (s Query) SelectKindID(kind graph.Kind) (int16, error) { + var ( + kindID int16 + result = s.tx.Raw(sqlSelectKindID, map[string]any{ + "name": kind.String(), + }) + ) + + defer result.Close() + + if !result.Next() { + return -1, pgx.ErrNoRows + } + + if err := result.Scan(&kindID); err != nil { + return -1, err + } + + return kindID, result.Error() +} + +func (s Query) assertGraphPartitionIndexes(partitions model.GraphPartitions, indexChanges model.IndexChangeSet) error { + for _, indexToRemove := range append(indexChanges.NodeIndexesToRemove, indexChanges.EdgeIndexesToRemove...) { + if err := s.DropIndex(indexToRemove); err != nil { + return err + } + } + + for _, constraintToRemove := range append(indexChanges.NodeConstraintsToRemove, indexChanges.EdgeConstraintsToRemove...) 
{ + if err := s.DropConstraint(constraintToRemove); err != nil { + return err + } + } + + for indexName, index := range indexChanges.NodeIndexesToAdd { + if err := s.CreatePropertyIndex(indexName, partitions.Node.Name, index.Field, index.Type); err != nil { + return err + } + } + + for constraintName, constraint := range indexChanges.NodeConstraintsToAdd { + if err := s.CreatePropertyConstraint(constraintName, partitions.Node.Name, constraint.Field, constraint.Type); err != nil { + return err + } + } + + for indexName, index := range indexChanges.EdgeIndexesToAdd { + if err := s.CreatePropertyIndex(indexName, partitions.Edge.Name, index.Field, index.Type); err != nil { + return err + } + } + + for constraintName, constraint := range indexChanges.EdgeConstraintsToAdd { + if err := s.CreatePropertyConstraint(constraintName, partitions.Edge.Name, constraint.Field, constraint.Type); err != nil { + return err + } + } + + return nil +} + +func (s Query) AssertGraph(schema graph.Graph, definition model.Graph) (model.Graph, error) { + var ( + requiredNodePartition = model.NewGraphPartitionFromSchema(definition.Partitions.Node.Name, schema.NodeIndexes, schema.NodeConstraints) + requiredEdgePartition = model.NewGraphPartitionFromSchema(definition.Partitions.Edge.Name, schema.EdgeIndexes, schema.EdgeConstraints) + indexChangeSet = model.NewIndexChangeSet() + ) + + if presentNodePartition, err := s.describeGraphPartition(definition.Partitions.Node.Name); err != nil { + return model.Graph{}, err + } else { + for presentNodeIndexName := range presentNodePartition.Indexes { + if _, hasMatchingDefinition := requiredNodePartition.Indexes[presentNodeIndexName]; !hasMatchingDefinition { + indexChangeSet.NodeIndexesToRemove = append(indexChangeSet.NodeIndexesToRemove, presentNodeIndexName) + } + } + + for presentNodeConstraintName := range presentNodePartition.Constraints { + if _, hasMatchingDefinition := requiredNodePartition.Constraints[presentNodeConstraintName]; 
!hasMatchingDefinition { + indexChangeSet.NodeConstraintsToRemove = append(indexChangeSet.NodeConstraintsToRemove, presentNodeConstraintName) + } + } + + for requiredNodeIndexName, requiredNodeIndex := range requiredNodePartition.Indexes { + if presentNodeIndex, hasMatchingDefinition := presentNodePartition.Indexes[requiredNodeIndexName]; !hasMatchingDefinition { + indexChangeSet.NodeIndexesToAdd[requiredNodeIndexName] = requiredNodeIndex + } else if requiredNodeIndex.Type != presentNodeIndex.Type { + indexChangeSet.NodeIndexesToRemove = append(indexChangeSet.NodeIndexesToRemove, requiredNodeIndexName) + indexChangeSet.NodeIndexesToAdd[requiredNodeIndexName] = requiredNodeIndex + } + } + + for requiredNodeConstraintName, requiredNodeConstraint := range requiredNodePartition.Constraints { + if presentNodeConstraint, hasMatchingDefinition := presentNodePartition.Constraints[requiredNodeConstraintName]; !hasMatchingDefinition { + indexChangeSet.NodeConstraintsToAdd[requiredNodeConstraintName] = requiredNodeConstraint + } else if requiredNodeConstraint.Type != presentNodeConstraint.Type { + indexChangeSet.NodeConstraintsToRemove = append(indexChangeSet.NodeConstraintsToRemove, requiredNodeConstraintName) + indexChangeSet.NodeConstraintsToAdd[requiredNodeConstraintName] = requiredNodeConstraint + } + } + } + + if presentEdgePartition, err := s.describeGraphPartition(definition.Partitions.Edge.Name); err != nil { + return model.Graph{}, err + } else { + for presentEdgeIndexName := range presentEdgePartition.Indexes { + if _, hasMatchingDefinition := requiredEdgePartition.Indexes[presentEdgeIndexName]; !hasMatchingDefinition { + indexChangeSet.EdgeIndexesToRemove = append(indexChangeSet.EdgeIndexesToRemove, presentEdgeIndexName) + } + } + + for presentEdgeConstraintName := range presentEdgePartition.Constraints { + if _, hasMatchingDefinition := requiredEdgePartition.Constraints[presentEdgeConstraintName]; !hasMatchingDefinition { + indexChangeSet.EdgeConstraintsToRemove 
= append(indexChangeSet.EdgeConstraintsToRemove, presentEdgeConstraintName) + } + } + + for requiredEdgeIndexName, requiredEdgeIndex := range requiredEdgePartition.Indexes { + if presentEdgeIndex, hasMatchingDefinition := presentEdgePartition.Indexes[requiredEdgeIndexName]; !hasMatchingDefinition { + indexChangeSet.EdgeIndexesToAdd[requiredEdgeIndexName] = requiredEdgeIndex + } else if requiredEdgeIndex.Type != presentEdgeIndex.Type { + indexChangeSet.EdgeIndexesToRemove = append(indexChangeSet.EdgeIndexesToRemove, requiredEdgeIndexName) + indexChangeSet.EdgeIndexesToAdd[requiredEdgeIndexName] = requiredEdgeIndex + } + } + + for requiredEdgeConstraintName, requiredEdgeConstraint := range requiredEdgePartition.Constraints { + if presentEdgeConstraint, hasMatchingDefinition := presentEdgePartition.Constraints[requiredEdgeConstraintName]; !hasMatchingDefinition { + indexChangeSet.EdgeConstraintsToAdd[requiredEdgeConstraintName] = requiredEdgeConstraint + } else if requiredEdgeConstraint.Type != presentEdgeConstraint.Type { + indexChangeSet.EdgeConstraintsToRemove = append(indexChangeSet.EdgeConstraintsToRemove, requiredEdgeConstraintName) + indexChangeSet.EdgeConstraintsToAdd[requiredEdgeConstraintName] = requiredEdgeConstraint + } + } + } + + return model.Graph{ + ID: definition.ID, + Name: definition.Name, + Partitions: model.GraphPartitions{ + Node: requiredNodePartition, + Edge: requiredEdgePartition, + }, + }, s.assertGraphPartitionIndexes(definition.Partitions, indexChangeSet) +} + +func (s Query) createGraphPartitions(definition model.Graph) (model.Graph, error) { + var ( + nodePartitionName = model.NodePartitionTableName(definition.ID) + edgePartitionName = model.EdgePartitionTableName(definition.ID) + ) + + if nodePartition, err := s.CreatePartitionTable(nodePartitionName, model.NodeTable, definition.ID); err != nil { + return model.Graph{}, err + } else { + definition.Partitions.Node = nodePartition + } + + if edgePartition, err := 
s.CreatePartitionTable(edgePartitionName, model.EdgeTable, definition.ID); err != nil { + return model.Graph{}, err + } else { + definition.Partitions.Edge = edgePartition + } + + return definition, nil +} + +func (s Query) CreateGraph(schema graph.Graph) (model.Graph, error) { + if definition, err := s.insertGraph(schema.Name); err != nil { + return model.Graph{}, err + } else if definition, err := s.createGraphPartitions(definition); err != nil { + return model.Graph{}, err + } else { + return s.AssertGraph(schema, definition) + } +} + +func (s Query) InsertOrGetKind(kind graph.Kind) (int16, error) { + var ( + kindID int16 + result = s.tx.Raw(sqlInsertKind, map[string]any{ + "name": kind.String(), + }) + ) + + defer result.Close() + + if !result.Next() { + return -1, pgx.ErrNoRows + } + + if err := result.Scan(&kindID); err != nil { + return -1, err + } + + return kindID, result.Error() +} diff --git a/packages/go/dawgs/drivers/pg/query/sql.go b/packages/go/dawgs/drivers/pg/query/sql.go new file mode 100644 index 0000000000..bde21c72ac --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql.go @@ -0,0 +1,71 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package query + +import ( + "embed" + "fmt" + "path" + "strings" +) + +var ( + //go:embed sql + queryFS embed.FS +) + +func stripSQLComments(multiLineContent string) string { + builder := strings.Builder{} + + for _, line := range strings.Split(multiLineContent, "\n") { + trimmedLine := strings.TrimSpace(line) + + // Strip empty and SQL comment lines + if len(trimmedLine) == 0 || strings.HasPrefix(trimmedLine, "--") { + continue + } + + builder.WriteString(trimmedLine) + builder.WriteString("\n") + } + + return builder.String() +} + +func readFile(name string) string { + if content, err := queryFS.ReadFile(name); err != nil { + panic(fmt.Sprintf("Unable to find embedded query file %s: %v", name, err)) + } else { + return stripSQLComments(string(content)) + } +} + +func loadSQL(name string) string { + return readFile(path.Join("sql", name)) +} + +var ( + sqlSchemaUp = loadSQL("schema_up.sql") + sqlSchemaDown = loadSQL("schema_down.sql") + sqlSelectTableIndexes = loadSQL("select_table_indexes.sql") + sqlSelectKindID = loadSQL("select_kind_id.sql") + sqlSelectGraphs = loadSQL("select_graphs.sql") + sqlInsertGraph = loadSQL("insert_graph.sql") + sqlInsertKind = loadSQL("insert_or_get_kind.sql") + sqlSelectKinds = loadSQL("select_kinds.sql") + sqlSelectGraphByName = loadSQL("select_graph_by_name.sql") +) diff --git a/packages/go/dawgs/drivers/pg/query/sql/insert_graph.sql b/packages/go/dawgs/drivers/pg/query/sql/insert_graph.sql new file mode 100644 index 0000000000..9b8693adb1 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/insert_graph.sql @@ -0,0 +1,20 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Creates a new graph and returns the resulting graph ID. +insert into graph (name) +values (@name) +returning id; diff --git a/packages/go/dawgs/drivers/pg/query/sql/insert_or_get_kind.sql b/packages/go/dawgs/drivers/pg/query/sql/insert_or_get_kind.sql new file mode 100644 index 0000000000..6340a63e7e --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/insert_or_get_kind.sql @@ -0,0 +1,28 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Creates a new kind definition if it does not exist and returns the resulting ID. If the +-- kind already exists then the kind's assigned ID is returned. 
+with + existing as ( + select id from kind where kind.name = @name + ), + inserted as ( + insert into kind (name) values (@name) on conflict (name) do nothing returning id + ) +select * from existing +union +select * from inserted; diff --git a/packages/go/dawgs/drivers/pg/query/sql/schema_down.sql b/packages/go/dawgs/drivers/pg/query/sql/schema_down.sql new file mode 100644 index 0000000000..19da3f24c9 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/schema_down.sql @@ -0,0 +1,93 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Drop triggers +drop trigger if exists delete_node_edges on node; +drop function if exists delete_node_edges; + +-- Drop functions +drop function if exists query_perf; +drop function if exists lock_details; +drop function if exists table_sizes; +drop function if exists get_node; +drop function if exists node_prop; +drop function if exists kinds; +drop function if exists has_kind; +drop function if exists mt_get_root; +drop function if exists index_utilization; +drop function if exists _format_asp_where_clause; +drop function if exists _format_asp_query; +drop function if exists _all_shortest_paths; +drop function if exists all_shortest_paths; +drop function if exists traversal_step; +drop function if exists _format_traversal_continuation_termination; +drop function if exists _format_traversal_query; +drop function if exists _format_traversal_initial_query; +drop function if exists expand_traversal_step; +drop function if exists traverse; +drop function if exists edges_to_path; +drop function if exists traverse_paths; + +-- Drop all tables in order of dependency. +drop table if exists node; +drop table if exists edge; +drop table if exists kind; +drop table if exists graph; + +-- Remove custom types +do +$$ + begin + drop type pathComposite; + exception + when undefined_object then null; + end +$$; + +do +$$ + begin + drop type nodeComposite; + exception + when undefined_object then null; + end +$$; + +do +$$ + begin + drop type edgeComposite; + exception + when undefined_object then null; + end +$$; + +do +$$ + begin + drop type _traversal_step; + exception + when undefined_object then null; + end +$$; + +-- Pull the tri-gram and intarray extensions. 
+drop + extension if exists pg_trgm; +drop + extension if exists intarray; +drop + extension if exists pg_stat_statements; diff --git a/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql new file mode 100644 index 0000000000..cdfa8133dc --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/schema_up.sql @@ -0,0 +1,941 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- DAWGS Property Graph Partitioned Layout for PostgreSQL + +-- Notes on TOAST: +-- +-- Graph entity properties are stored in a JSONB column at the end of the row. There is a soft-limit of 2KiB for rows in +-- a PostgreSQL database page. The database will compress this value in an attempt not to exceed this limit. Once a +-- compressed value reaches the absolute limit of what the database can do to either compact it or give it more of the +-- 8 KiB page size limit, the database evicts the value to an associated TOAST (The Oversized-Attribute Storage Technique) +-- table and creates a reference to the entry to be joined upon fetch of the row. +-- +-- TOAST comes with certain performance caveats that can affect access time anywhere from a factor 3 to 6 times. It is +-- in the best interest of the database user that the properties of a graph entity never exceed this limit in large +-- graphs. + +-- We need the tri-gram extension to create a GIN text-search index. 
The goal here isn't full-text search, in which +-- case ts_vector and its ilk would be more suited. This particular selection was made to support accelerated lookups +-- for "contains", "starts with" and, "ends with" comparison operations. +create extension if not exists pg_trgm; + +-- We need the intarray extension for extended integer array operations like unions. This is useful for managing kind +-- arrays for nodes. +create extension if not exists intarray; + +-- This is an optional but useful extension for validating performance of queries +-- create extension if not exists pg_stat_statements; +-- +-- create or replace function public.query_perf() +-- returns table +-- ( +-- query text, +-- calls int, +-- total_time numeric, +-- mean_time numeric, +-- percent_total_time numeric +-- ) +-- as +-- $$ +-- select query as query, +-- calls as calls, +-- round(total_exec_time::numeric, 2) as total_time, +-- round(mean_exec_time::numeric, 2) as mean_time, +-- round((100 * total_exec_time / sum(total_exec_time) over ()):: numeric, 2) as percent_total_time +-- from pg_stat_statements +-- order by total_exec_time desc +-- limit 25 +-- $$ +-- language sql +-- immutable +-- parallel safe +-- strict; + +-- Table definitions + +-- The graph table contains name to ID mappings for graphs contained within the database. Each graph ID should have +-- corresponding table partitions for the node and edge tables. +create table if not exists graph +( + id serial, + name varchar(256) not null, + primary key (id), + unique (name) +); + +-- The kind table contains name to ID mappings for graph kinds. Storage of these types is necessary to maintain search +-- capability of a database without the origin application that generated it. 
+create table if not exists kind +( + id smallserial, + name varchar(256) not null, + primary key (id), + unique (name) +); + +-- Node composite type +do +$$ + begin + create type nodeComposite as + ( + id integer, + kind_ids smallint[8], + properties jsonb + ); + exception + when duplicate_object then null; + end +$$; + +-- The node table is a partitioned table view that partitions over the graph ID that each node belongs to. Nodes may +-- contain a disjunction of up to 8 kinds for creating clique subsets without requiring edges. +create table if not exists node +( + id serial not null, + graph_id integer not null, + kind_ids smallint[8] not null, + properties jsonb not null, + + primary key (id, graph_id), + foreign key (graph_id) references graph (id) on delete cascade +) partition by list (graph_id); + +-- The storage strategy chosen for the properties JSONB column informs the database of the user's preference to resort +-- to creating a TOAST table entry only after there is no other possible way to inline the row attribute in the current +-- page. +alter table node + alter column properties set storage main; + +-- Index on the graph ID of each node. +create index if not exists node_graph_id_index on node using btree (graph_id); + +-- Index node kind IDs so that lookups by kind is accelerated. +create index if not exists node_kind_ids_index on node using gin (kind_ids); + +-- Edge composite type +do +$$ + begin + create type edgeComposite as + ( + id integer, + start_id integer, + end_id integer, + kind_id smallint, + properties jsonb + ); + exception + when duplicate_object then null; + end +$$; + +-- The edge table is a partitioned table view that partitions over the graph ID that each edge belongs to. 
+create table if not exists edge
+(
+ id serial not null,
+ graph_id integer not null,
+ start_id integer not null,
+ end_id integer not null,
+ kind_id smallint not null,
+ properties jsonb not null,
+
+ primary key (id, graph_id),
+ foreign key (graph_id) references graph (id) on delete cascade
+) partition by list (graph_id);
+
+-- delete_node_edges is a trigger and associated plpgsql function to cascade delete edges when attached nodes are
+-- deleted. While this could be done with a foreign key relationship, it would scope the cascade delete to individual
+-- node partitions and therefore require the graph_id value of each node as part of the delete statement.
+create or replace function delete_node_edges() returns trigger as
+$$
+begin
+ delete from edge where start_id = OLD.id or end_id = OLD.id;
+ return null;
+end
+$$
+ language plpgsql;
+
+-- Drop and create the delete_node_edges trigger for the delete_node_edges() plpgsql function. See the function comment
+-- for more information.
+drop trigger if exists delete_node_edges on node;
+create trigger delete_node_edges
+ after delete
+ on node
+ for each row
+execute procedure delete_node_edges();
+
+
+-- The storage strategy chosen for the properties JSONB column informs the database of the user's preference to resort
+-- to creating a TOAST table entry only after there is no other possible way to inline the row attribute in the current
+-- page.
+alter table edge
+ alter column properties set storage main;
+
+
+-- Index on the graph ID of each edge.
+create index if not exists edge_graph_id_index on edge using btree (graph_id);
+
+-- Index on the start vertex of each edge.
+create index if not exists edge_start_id_index on edge using btree (start_id);
+
+-- Index on the end vertex of each edge.
+create index if not exists edge_end_id_index on edge using btree (end_id);
+
+-- Index on the kind of each edge.
+create index if not exists edge_kind_index on edge using btree (kind_id); + +-- Path composite type +do +$$ + begin + create type pathComposite as + ( + nodes nodeComposite[], + edges edgeComposite[] + ); + exception + when duplicate_object then null; + end +$$; + +-- Database helper functions +create or replace function public.lock_details() + returns table + ( + datname text, + locktype text, + relation text, + lock_mode text, + txid xid, + virtual_txid text, + pid integer, + tx_granted bool, + client_addr text, + client_port integer, + elapsed_time interval + ) +as +$$ +select db.datname as datname, + locktype as locktype, + relation::regclass as relation, + mode as lock_mode, + transactionid as txid, + virtualtransaction as virtual_txid, + l.pid as pid, + granted as tx_granted, + psa.client_addr as client_addr, + psa.client_port as client_port, + now() - psa.query_start as elapsed_time +from pg_catalog.pg_locks l + left join pg_catalog.pg_database db on db.oid = l.database + left join pg_catalog.pg_stat_activity psa on l.pid = psa.pid +where not l.pid = pg_backend_pid(); +$$ + language sql + immutable + parallel safe + strict; + +create or replace function public.table_sizes() + returns table + ( + oid int, + table_schema text, + table_name text, + total_bytes numeric, + total_size text, + index_size text, + toast_size text, + table_size text + ) +as +$$ +select oid as oid, + table_schema as table_schema, + table_name as table_name, + total_bytes as total_bytes, + pg_size_pretty(total_bytes) as total_size, + pg_size_pretty(index_bytes) as index_size, + pg_size_pretty(toast_bytes) as toast_size, + pg_size_pretty(table_bytes) as table_size +from (select *, total_bytes - index_bytes - coalesce(toast_bytes, 0) as table_bytes + from (select c.oid as oid, + nspname as table_schema, + relname as table_name, + c.reltuples as row_estimate, + pg_total_relation_size(c.oid) as total_bytes, + pg_indexes_size(c.oid) as index_bytes, + pg_total_relation_size(reltoastrelid) as 
toast_bytes + from pg_class c + left join pg_namespace n on n.oid = c.relnamespace + where relkind = 'r') a) a +order by total_bytes desc; +$$ + language sql + immutable + parallel safe + strict; + +create or replace function public.index_utilization() + returns table + ( + table_name text, + idx_scans int, + seq_scans int, + index_usage int, + rows_in_table int + ) +as +$$ +select relname table_name, + idx_scan index_scan, + seq_scan table_scan, + 100 * idx_scan / (seq_scan + idx_scan) index_usage, + n_live_tup rows_in_table +from pg_stat_user_tables +where seq_scan + idx_scan > 0 +order by index_usage desc +limit 25; +$$ + language sql + immutable + parallel safe + strict; + +-- Graph helper functions +create or replace function public.kinds(target anyelement) returns text[] as +$$ +begin + if pg_typeof(target) = 'node'::regtype then + return (select array_agg(k.name) from kind k where k.id = any (target.kind_ids)); + elsif pg_typeof(target) = 'edge'::regtype then + return (select array_agg(k.name) from kind k where k.id = target.kind_id); + elsif pg_typeof(target) = 'int[]'::regtype then + return (select array_agg(k.name) from kind k where k.id = any (target::int2[])); + elsif pg_typeof(target) = 'int'::regtype then + return (select array_agg(k.name) from kind k where k.id = target::int2); + elsif pg_typeof(target) = 'int2[]'::regtype then + return (select array_agg(k.name) from kind k where k.id = any (target)); + elsif pg_typeof(target) = 'int2'::regtype then + return (select array_agg(k.name) from kind k where k.id = target); + else + raise exception 'Invalid argument type: %', pg_typeof(target) using hint = 'Type must be either node, edge, int[], int, int2[] or int2'; + end if; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public.has_kind(target anyelement, variadic kind_name_in text[]) returns bool as +$$ +begin + if pg_typeof(target) = 'node'::regtype then + return exists(select 1 + where target.kind_ids 
operator (pg_catalog.&&) + (select array_agg(id) from kind k where k.name = any (kind_name_in))); + elsif pg_typeof(target) = 'edge'::regtype then + return exists(select 1 + where target.kind_id in (select id from kind k where k.name = any (kind_name_in))); + else + raise exception 'Invalid argument type: %', pg_typeof(target) using hint = 'Type must be either node or edge'; + end if; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create + or replace function public.get_node(id_in int4) + returns setof node as +$$ +select * +from node n +where n.id = id_in; +$$ + language sql immutable + parallel safe + strict; + +create + or replace function public.node_prop(target anyelement, property_name text) + returns jsonb as +$$ +begin + if pg_typeof(target) = 'node'::regtype then + return target.properties -> property_name; + elsif pg_typeof(target) = 'int4'::regtype then + return (select n.properties -> property_name from node n where n.id = target limit 1); + else + raise exception 'Invalid argument type: %', pg_typeof(target) using hint = 'Type must be either node or edge'; + end if; +end; +$$ + language plpgsql immutable + parallel safe + strict; + + +create or replace function public.mt_get_root(owner_object_id text) returns setof node as +$$ +select * +from node n +where has_kind(n, 'Meta') + and n.properties ->> 'system_tags' like '%admin_tier_0%' + and n.properties ->> 'owner_objectid' = owner_object_id; +$$ + language sql immutable + parallel safe + strict; + +-- All shortest path traversal functions and schema + +create or replace function public._format_asp_where_clause(root_criteria text, where_clause text) returns text as +$$ +declare + formatted_query text := ''; +begin + if length(root_criteria) > 0 then + if length(where_clause) > 0 then + formatted_query := ' where ' || root_criteria || ' and ' || where_clause; + else + formatted_query := ' where ' || root_criteria; + end if; + elsif length(where_clause) > 0 then + formatted_query := ' 
where ' || where_clause; + end if; + + return formatted_query; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public._format_asp_query(terminal_criteria text, cycle_criteria text, + traversal_criteria text default '', + root_criteria text default '', + bind_pathspace bool default true, + bind_start bool default false, + bind_end bool default false) returns text as +$$ +declare + formatted_query text := 'insert into pathspace_next (path, next, is_terminal, is_cycle) '; +begin + if bind_pathspace then + formatted_query := + formatted_query || 'select p.path || r.id, r.end_id, ' || terminal_criteria || ', ' || cycle_criteria || + ' from edge r join pathspace_current p on p.next = r.start_id'; + else + formatted_query := formatted_query || 'select array [r.id]::int4[], r.end_id, ' || terminal_criteria || ', ' || + cycle_criteria || + ' from edge r'; + end if; + + if bind_start then + formatted_query := formatted_query || ' join node s on s.id = r.start_id'; + end if; + + if bind_end then + formatted_query := formatted_query || ' join node e on e.id = r.end_id '; + end if; + + formatted_query := formatted_query || _format_asp_where_clause(root_criteria, traversal_criteria) || ';'; + + raise notice '_format_asp_query -> %', formatted_query; + return formatted_query; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public._all_shortest_paths(root_criteria text, + traversal_criteria text, + terminal_criteria text, + max_depth int4) + returns table + ( + path int4[] + ) +as +$$ +declare + has_root_criteria bool := length(root_criteria) > 0; + has_traversal_criteria bool := length(traversal_criteria) > 0; + has_terminal_criteria bool := length(terminal_criteria) > 0; + + -- Make sure to take into account if queries will need the start or end node of edges bound by a join + bind_root_node bool := has_root_criteria and root_criteria like '%s.%'; + bind_terminal_node bool := 
has_terminal_criteria and terminal_criteria like '%e.%'; + bind_traversal_start_node bool := has_traversal_criteria and traversal_criteria like '%s.%'; + bind_traversal_end_node bool := has_traversal_criteria and traversal_criteria like '%e.%'; + depth int4 := 1; +begin + -- Create two unlogged (no WAL writes) temporary tables (invisible to other sessions) for storing traversal + -- fronts during path expansion. + create temporary table pathspace_current + ( + path int4[] not null, + next int4 not null, + is_terminal bool not null, + is_cycle bool not null, + primary key (path) + ) on commit drop; + + create temporary table pathspace_next + ( + path int4[] not null, + next int4 not null, + is_terminal bool not null, + is_cycle bool not null, + primary key (path) + ) on commit drop; + + -- Create an index on the next node ID to accelerate joins + create index if not exists pathspace_current_next_index on pathspace_current using btree (next); + create index if not exists pathspace_next_next_index on pathspace_next using btree (next); + + -- Create an index on the is_terminal boolean to accelerate aggregation and selection + create index if not exists pathspace_current_terminal_index on pathspace_current using btree (is_terminal); + create index if not exists pathspace_next_terminal_index on pathspace_next using btree (is_terminal); + + -- Initial expansion to acquire the first traversal front + execute _format_asp_query(terminal_criteria := terminal_criteria, + cycle_criteria := 'r.start_id = r.end_id', + bind_pathspace := false, + bind_start := bind_traversal_start_node or bind_root_node, + bind_end := bind_traversal_end_node or bind_terminal_node, + root_criteria := root_criteria, + traversal_criteria := traversal_criteria); + + -- Copy from the next pathspace table to the current pathspace table. Any non-terminal cycles are omitted as + -- part of this copy to prune visited branches. 
+ insert into pathspace_current select * from pathspace_next p where not p.is_cycle or p.is_terminal; + + -- Truncate the next pathspace table to clear it + truncate pathspace_next; + + -- Loop until either the current depth exceeds the max allowed depth or if any of the paths are terminal + while depth < max_depth and (select count(*) from pathspace_current p where p.is_terminal) = 0 + loop + -- Increase the depth counter as we're expanding a new front + depth := depth + 1; + + -- Perform the next pathspace expansion + execute _format_asp_query(terminal_criteria := terminal_criteria, + cycle_criteria := 'r.id = any (p.path)', + bind_pathspace := true, + bind_start := bind_traversal_start_node, + bind_end := bind_terminal_node or bind_traversal_end_node, + traversal_criteria := traversal_criteria); + + -- Truncate the old pathspace table to clear it + truncate pathspace_current; + + -- Copy from the next pathspace table to the current pathspace table. Any non-terminal cycles are omitted as + -- part of this copy to prune visited branches. 
+ insert into pathspace_current select * from pathspace_next p where not p.is_cycle or p.is_terminal; + + -- Truncate the next pathspace table to clear it + truncate pathspace_next; + end loop; + + -- Return the raw path (set of edge IDs) for each path found in pathspace + return query select p.path + from pathspace_current p + -- Select only terminal paths + where p.is_terminal; + + -- Close the set + return; +end; +$$ + language plpgsql volatile + strict; + +create or replace function public.edges_to_path(path variadic int4[]) returns pathComposite as +$$ +select row (array_agg(distinct (n.id, n.kind_ids, n.properties)::nodeComposite)::nodeComposite[], + array_agg(distinct (r.id, r.start_id, r.end_id, r.kind_id, r.properties)::edgeComposite)::edgeComposite[])::pathComposite +from edge r + join node n on n.id = r.start_id or n.id = r.end_id +where r.id = any (path); +$$ + language sql + immutable + parallel safe + strict; + +create or replace function public.all_shortest_paths(root_criteria text, + traversal_criteria text, + terminal_criteria text, + max_depth int4) + returns pathComposite +as +$$ +declare + paths pathcomposite; +begin + select array_agg(distinct (n.id, n.kind_ids, n.properties)::nodeComposite)::nodeComposite[], + array_agg(distinct + (r.id, r.start_id, r.end_id, r.kind_id, r.properties)::edgeComposite)::edgeComposite[] + into paths + from _all_shortest_paths(root_criteria, traversal_criteria, terminal_criteria, + max_depth) as t + join edge r on r.id = any (t.path) + join node n on n.id = r.start_id or n.id = r.end_id; + + return paths; +end; +$$ + language plpgsql + immutable + strict; + +-- Generic traversal functions and schema +do +$$ + begin + create type _traversal_step as + ( + root_criteria text, + traversal_criteria text, + terminal_criteria text, + max_depth integer + ); + exception + when duplicate_object then null; + end +$$; + +create or replace function public.traversal_step(root_criteria text default '', + traversal_criteria text 
default '', + terminal_criteria text default '', + max_depth integer default 0) + returns _traversal_step as +$$ +begin + return (root_criteria, traversal_criteria, terminal_criteria, max_depth)::_traversal_step; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public._format_traversal_continuation_termination(terminal_criteria text, + bind_traversal_start_node bool, + bind_traversal_end_node bool) + returns text as +$$ +declare + formatted_query text := 'update pathspace_current p set terminal = true from edge r'; + where_clause text := ' where not p.terminal and p.exhausted and r.start_id = p.path[array_length(p.path, 1)]'; +begin + if bind_traversal_start_node then + formatted_query := formatted_query || ', node s'; + where_clause := where_clause || ' and s.id = r.start_id'; + end if; + + if bind_traversal_end_node then + formatted_query := formatted_query || ', node e'; + where_clause := where_clause || ' and e.id = r.end_id'; + end if; + + return formatted_query || where_clause || ' and ' || terminal_criteria; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public._format_traversal_query(traversal_criteria text, + terminal_criteria text, + mark_terminal bool, + bind_traversal_start_node bool, + bind_traversal_end_node bool) + returns text as +$$ +declare + formatted_query text := 'with inserts as (insert into pathspace_next (path, next, terminal, exhausted, rejected) '; +begin + formatted_query := formatted_query || 'select $1 || r.id, r.end_id, '; + + if length(terminal_criteria) > 0 then + formatted_query := formatted_query || terminal_criteria || ', '; + else + formatted_query := formatted_query || mark_terminal || ', '; + end if; + + formatted_query := + formatted_query || 'false, r.id = any ($1) from edge r '; + + if bind_traversal_start_node then + formatted_query := formatted_query || ' join node s on s.id = r.start_id'; + end if; + + if bind_traversal_end_node 
then + formatted_query := formatted_query || ' join node e on e.id = r.end_id '; + end if; + + formatted_query := formatted_query || ' where r.start_id = $2'; + + if length(traversal_criteria) > 0 then + formatted_query := formatted_query || ' and ' || traversal_criteria; + end if; + + return formatted_query || ' returning true) select count(*) from inserts;'; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public._format_traversal_initial_query(root_criteria text, + terminal_criteria text, + mark_terminal bool, + traversal_criteria text, + bind_root_node bool, + bind_terminal_node bool) returns text as +$$ +declare + formatted_query text := 'insert into pathspace_current (path, next, terminal, exhausted, rejected) '; +begin + formatted_query := formatted_query || 'select array [r.id]::int4[], r.end_id, '; + + if length(terminal_criteria) > 0 then + formatted_query := formatted_query || terminal_criteria || ', '; + else + formatted_query := formatted_query || mark_terminal || ', '; + end if; + + formatted_query := formatted_query || 'false, r.start_id = r.end_id from edge r'; + + if bind_root_node then + formatted_query := formatted_query || ' join node s on s.id = r.start_id'; + end if; + + if bind_terminal_node then + formatted_query := formatted_query || ' join node e on e.id = r.end_id '; + end if; + + if length(root_criteria) > 0 then + if length(traversal_criteria) > 0 then + formatted_query := formatted_query || ' where ' || root_criteria || ' and ' || traversal_criteria; + else + formatted_query := formatted_query || ' where ' || root_criteria; + end if; + elsif length(traversal_criteria) > 0 then + formatted_query := formatted_query || ' where ' || traversal_criteria; + end if; + + return formatted_query; +end; +$$ + language plpgsql immutable + parallel safe + strict; + +create or replace function public.expand_traversal_step(step _traversal_step, + continuation bool default false, + last_continuation bool 
default false) + returns void as +$$ +declare + incomplete_path record; + num_expansions int8; + has_root_criteria bool := length(step.root_criteria) > 0; + has_traversal_criteria bool := length(step.traversal_criteria) > 0; + has_terminal_criteria bool := length(step.terminal_criteria) > 0; + + -- Make sure to take into account if queries will need the start or end node of edges bound by a join + bind_root_node bool := has_root_criteria and step.root_criteria like '%s.%'; + bind_terminal_node bool := has_terminal_criteria and step.terminal_criteria like '%e.%'; + bind_traversal_start_node bool := has_traversal_criteria and step.traversal_criteria like '%s.%'; + bind_traversal_end_node bool := has_traversal_criteria and step.traversal_criteria like '%e.%'; + depth int4 := 0; +begin + if not continuation then + raise notice 'Starting a new traversal'; + + -- Increase the depth counter as we're expanding a new front + depth := depth + 1; + + -- Perform the initial expansion to acquire the first traversal front + execute _format_traversal_initial_query( + root_criteria := step.root_criteria, + bind_root_node := bind_root_node or bind_traversal_start_node, + terminal_criteria := step.terminal_criteria, + mark_terminal := last_continuation and not has_terminal_criteria and depth = step.max_depth, + bind_terminal_node := bind_terminal_node or bind_traversal_end_node, + traversal_criteria := step.traversal_criteria); + + -- Copy from the next pathspace table to the current pathspace table and omit rejected segments. 
+ insert into pathspace_current select * from pathspace_next p where not p.rejected or p.terminal; + + -- Truncate the next pathspace table to clear it + truncate pathspace_next; + else + raise notice 'Continuing traversal'; + + if last_continuation then + raise notice 'This is the last continuation.'; + end if; + + if has_terminal_criteria then + -- If this is a traversal continuation then we must validate any exhausted paths that may also be terminal + execute _format_traversal_continuation_termination( + terminal_criteria := step.terminal_criteria, + bind_traversal_start_node := bind_traversal_start_node, + bind_traversal_end_node := bind_terminal_node or bind_traversal_end_node); + + -- Dump any paths that are not terminal but exhausted as this will prune pathspace to only paths that are + -- eligible for further expansion + delete from pathspace_current p where not p.terminal and p.exhausted; + else + -- Mark all non-terminal paths as no longer exhausted as this is a continuation + update pathspace_current p set exhausted = false where not p.terminal; + end if; + end if; + + raise notice 'Current pathspace:'; + + for incomplete_path in select * from pathspace_current p + loop + raise notice 'path: % - terminal: %, exhausted: %, rejected: %', incomplete_path.path, incomplete_path.terminal, incomplete_path.exhausted, incomplete_path.rejected; + end loop; + + -- Loop until either the current depth exceeds the max allowed depth or if any of the paths are terminal + while depth < step.max_depth and + exists(select true from pathspace_current p where not p.terminal and not p.exhausted) + loop + raise notice 'Incomplete paths:'; + + for incomplete_path in select * from pathspace_current p where not p.terminal and not p.exhausted + loop + raise notice 'path: % - terminal: %, exhausted: %, rejected: %', incomplete_path.path, incomplete_path.terminal, incomplete_path.exhausted, incomplete_path.rejected; + end loop; + + -- Increase the depth counter as we're expanding a 
new front + depth := depth + 1; + + -- Copy all terminal segments + insert into pathspace_next select * from pathspace_current p where p.terminal; + + -- Expand all non-terminal, unexhausted segments + for incomplete_path in select * from pathspace_current p where not p.terminal and not p.exhausted + loop + -- Expand the next front for this segment + execute _format_traversal_query( + traversal_criteria := step.traversal_criteria, + terminal_criteria := step.terminal_criteria, + mark_terminal := last_continuation and not has_terminal_criteria and depth = step.max_depth, + bind_traversal_start_node := bind_traversal_start_node, + bind_traversal_end_node := bind_terminal_node or bind_traversal_end_node) + into num_expansions + using incomplete_path.path, incomplete_path.next; + + if num_expansions = 0 then + -- If there were no more expansions for this segment, insert into the next pathspace it as + -- exhausted. The terminal status of the segment may be set to true if this is the last + -- traversal continuation and there is no terminal criteria set. 
+ insert into pathspace_next (path, next, terminal, exhausted, rejected) + values (incomplete_path.path, 0, last_continuation and not has_terminal_criteria, true, + false); + end if; + end loop; + + -- Truncate the old pathspace table to clear it + truncate pathspace_current; + + -- Copy from the next pathspace into the current pathspace + insert into pathspace_current select * from pathspace_next p where not p.rejected; + + -- Truncate the next pathspace table to clear it + truncate pathspace_next; + end loop; + + raise notice 'Step pathspace:'; + for incomplete_path in select * from pathspace_current p + loop + raise notice 'path: % - terminal: %, exhausted: %, rejected: %', incomplete_path.path, incomplete_path.terminal, incomplete_path.exhausted, incomplete_path.rejected; + end loop; + + return; +end; +$$ + language plpgsql volatile + strict; + +create or replace function public.traverse(steps variadic _traversal_step[]) + returns table + ( + path int4[][] + ) +as +$$ +declare + step_idx int4 = 0; + next_step _traversal_step; +begin + -- Create two unlogged (no WAL writes) temporary tables (invisible to other sessions) for storing traversal + -- fronts during path expansion. 
+ create temporary table pathspace_current + ( + path int4[] not null, + next int4 not null, + terminal bool not null, + exhausted bool not null, + rejected bool not null, + primary key (path) + ) on commit drop; + + create temporary table pathspace_next + ( + path int4[] not null, + next int4 not null, + terminal bool not null, + exhausted bool not null, + rejected bool not null, + primary key (path) + ) on commit drop; + + -- Iterate through the traversal steps + foreach next_step in array steps + loop + step_idx := step_idx + 1; + + raise notice 'Array length: % - Is last continuation: %', array_length(steps, 1), step_idx = array_length(steps, 1); + + perform expand_traversal_step( + step := next_step, + continuation := step_idx > 1, + last_continuation := step_idx = array_length(steps, 1)); + end loop; + + -- Return the paths + return query select p.path from pathspace_current p where p.terminal; + + -- Close the set + return; +end; +$$ + language plpgsql volatile + strict; diff --git a/packages/go/dawgs/drivers/pg/query/sql/select_graph_by_name.sql b/packages/go/dawgs/drivers/pg/query/sql/select_graph_by_name.sql new file mode 100644 index 0000000000..fbace92665 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/select_graph_by_name.sql @@ -0,0 +1,20 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Selects the ID of a graph with the given name. 
+select id +from graph +where name = @name; diff --git a/packages/go/dawgs/drivers/pg/query/sql/select_graphs.sql b/packages/go/dawgs/drivers/pg/query/sql/select_graphs.sql new file mode 100644 index 0000000000..5cfff07344 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/select_graphs.sql @@ -0,0 +1,19 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Selects all defined graphs in the database. +select id, name +from graph; diff --git a/packages/go/dawgs/drivers/pg/query/sql/select_kind_id.sql b/packages/go/dawgs/drivers/pg/query/sql/select_kind_id.sql new file mode 100644 index 0000000000..7f9086ed47 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/select_kind_id.sql @@ -0,0 +1,20 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. 
+-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Selects the ID of a given Kind by name +select id +from kind +where name = @name; diff --git a/packages/go/dawgs/drivers/pg/query/sql/select_kinds.sql b/packages/go/dawgs/drivers/pg/query/sql/select_kinds.sql new file mode 100644 index 0000000000..040903aad9 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/select_kinds.sql @@ -0,0 +1,19 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- Selects all defined Kinds currently present in the database. +select id, name +from kind; diff --git a/packages/go/dawgs/drivers/pg/query/sql/select_table_indexes.sql b/packages/go/dawgs/drivers/pg/query/sql/select_table_indexes.sql new file mode 100644 index 0000000000..33f01f9bd9 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/query/sql/select_table_indexes.sql @@ -0,0 +1,21 @@ +-- Copyright 2023 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- List all indexes for a given table name. +select indexdef +from pg_indexes +where schemaname = 'public' + and tablename = @tablename; diff --git a/packages/go/dawgs/drivers/pg/relationship.go b/packages/go/dawgs/drivers/pg/relationship.go new file mode 100644 index 0000000000..aa425fe7bf --- /dev/null +++ b/packages/go/dawgs/drivers/pg/relationship.go @@ -0,0 +1,257 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "fmt" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/dawgs/query" +) + +func directionToReturnCriteria(direction graph.Direction) (graph.Criteria, error) { + switch direction { + case graph.DirectionInbound: + // Select the relationship and the end node + return query.Returning( + query.Relationship(), + query.End(), + ), nil + + case graph.DirectionOutbound: + // Select the relationship and the start node + return query.Returning( + query.Relationship(), + query.Start(), + ), nil + + default: + return nil, fmt.Errorf("bad direction: %d", direction) + } +} + +type relationshipQuery struct { + liveQuery +} + +func (s *relationshipQuery) Filter(criteria graph.Criteria) graph.RelationshipQuery { + s.queryBuilder.Apply(query.Where(criteria)) + return s +} + +func (s *relationshipQuery) Filterf(criteriaDelegate graph.CriteriaProvider) graph.RelationshipQuery { + return s.Filter(criteriaDelegate()) +} + +func (s *relationshipQuery) Delete() error { + return s.exec(query.Delete( + query.Relationship(), + )) +} + +func (s *relationshipQuery) Update(properties *graph.Properties) error { + return s.exec(query.Updatef(func() graph.Criteria { + var updateStatements []graph.Criteria + + if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { + updateStatements = append(updateStatements, query.SetProperties(query.Relationship(), modifiedProperties)) + } + + if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { + updateStatements = append(updateStatements, query.DeleteProperties(query.Relationship(), deletedProperties...)) + } + + return updateStatements + })) +} + +func (s *relationshipQuery) OrderBy(criteria ...graph.Criteria) graph.RelationshipQuery { + s.queryBuilder.Apply(query.OrderBy(criteria...)) + return s +} + +func (s *relationshipQuery) Offset(offset int) graph.RelationshipQuery { + 
s.queryBuilder.Apply(query.Offset(offset)) + return s +} + +func (s *relationshipQuery) Limit(limit int) graph.RelationshipQuery { + s.queryBuilder.Apply(query.Limit(limit)) + return s +} + +func (s *relationshipQuery) Count() (int64, error) { + var count int64 + + return count, s.Query(func(results graph.Result) error { + if !results.Next() { + return graph.ErrNoResultsFound + } + + return results.Scan(&count) + }, query.Returning( + query.Count(query.Relationship()), + )) +} + +// TODO: Max depth is relying on an uninformed default and should be passed either with criteria as an AST node or as an explicit parameter to this function +func (s *relationshipQuery) FetchAllShortestPaths(delegate func(cursor graph.Cursor[graph.Path]) error) error { + result := s.runAllShortestPathsQuery() + defer result.Close() + + if result.Error() != nil { + return result.Error() + } + + cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.Path, error) { + var path graph.Path + return path, scanner.Scan(&path) + }) + defer cursor.Close() + + return delegate(cursor) +} + +func (s *relationshipQuery) FetchTriples(delegate func(cursor graph.Cursor[graph.RelationshipTripleResult]) error) error { + return s.Query(func(result graph.Result) error { + cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.RelationshipTripleResult, error) { + var ( + startID graph.ID + relationshipID graph.ID + endID graph.ID + err = scanner.Scan(&startID, &relationshipID, &endID) + ) + + return graph.RelationshipTripleResult{ + ID: relationshipID, + StartID: startID, + EndID: endID, + }, err + }) + + defer cursor.Close() + return delegate(cursor) + }, query.ReturningDistinct( + query.StartID(), + query.RelationshipID(), + query.EndID(), + )) +} + +func (s *relationshipQuery) FetchKinds(delegate func(cursor graph.Cursor[graph.RelationshipKindsResult]) error) error { + return s.Query(func(result graph.Result) error { + cursor := 
graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.RelationshipKindsResult, error) { + var ( + startID graph.ID + relationshipID graph.ID + relationshipKind graph.Kind + endID graph.ID + err = scanner.Scan(&startID, &relationshipID, &relationshipKind, &endID) + ) + + return graph.RelationshipKindsResult{ + RelationshipTripleResult: graph.RelationshipTripleResult{ + ID: relationshipID, + StartID: startID, + EndID: endID, + }, + Kind: relationshipKind, + }, err + }) + + defer cursor.Close() + return delegate(cursor) + }, query.Returning( + query.StartID(), + query.RelationshipID(), + query.KindsOf(query.Relationship()), + query.EndID(), + )) +} + +func (s *relationshipQuery) First() (*graph.Relationship, error) { + var relationship graph.Relationship + + return &relationship, s.Query( + func(results graph.Result) error { + if !results.Next() { + return graph.ErrNoResultsFound + } + + return results.Scan(&relationship) + }, + query.Returning( + query.Relationship(), + ), + query.Limit(1), + ) +} + +func (s *relationshipQuery) Fetch(delegate func(cursor graph.Cursor[*graph.Relationship]) error) error { + return s.Query(func(result graph.Result) error { + cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (*graph.Relationship, error) { + var relationship graph.Relationship + return &relationship, scanner.Scan(&relationship) + }) + + defer cursor.Close() + return delegate(cursor) + }, query.Returning( + query.Relationship(), + )) +} + +func (s *relationshipQuery) FetchDirection(direction graph.Direction, delegate func(cursor graph.Cursor[graph.DirectionalResult]) error) error { + if returnCriteria, err := directionToReturnCriteria(direction); err != nil { + return err + } else { + return s.Query(func(result graph.Result) error { + cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.DirectionalResult, error) { + var ( + relationship graph.Relationship + node graph.Node + ) + + if err := 
scanner.Scan(&relationship, &node); err != nil { + return graph.DirectionalResult{}, err + } + + return graph.DirectionalResult{ + Direction: direction, + Relationship: &relationship, + Node: &node, + }, nil + }) + + defer cursor.Close() + return delegate(cursor) + }, returnCriteria) + } +} + +func (s *relationshipQuery) FetchIDs(delegate func(cursor graph.Cursor[graph.ID]) error) error { + return s.Query(func(result graph.Result) error { + cursor := graph.NewResultIterator(s.ctx, result, func(scanner graph.Scanner) (graph.ID, error) { + var relationshipID graph.ID + return relationshipID, scanner.Scan(&relationshipID) + }) + + defer cursor.Close() + return delegate(cursor) + }, query.Returning( + query.RelationshipID(), + )) +} diff --git a/packages/go/dawgs/drivers/pg/result.go b/packages/go/dawgs/drivers/pg/result.go new file mode 100644 index 0000000000..761bd23513 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/result.go @@ -0,0 +1,117 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "fmt" + "github.com/jackc/pgx/v5" + "github.com/specterops/bloodhound/dawgs/graph" +) + +type queryResult struct { + rows pgx.Rows + kindMapper KindMapper +} + +func (s *queryResult) Next() bool { + return s.rows.Next() +} + +func (s *queryResult) Values() (graph.ValueMapper, error) { + if values, err := s.rows.Values(); err != nil { + return nil, err + } else { + return NewValueMapper(values, s.kindMapper), nil + } +} + +func (s *queryResult) Scan(targets ...any) error { + pgTargets := make([]any, 0, len(targets)) + + for _, target := range targets { + switch target.(type) { + case *graph.Path: + pgTargets = append(pgTargets, &pathComposite{}) + + case *graph.Relationship: + pgTargets = append(pgTargets, &edgeComposite{}) + + case *graph.Node: + pgTargets = append(pgTargets, &nodeComposite{}) + + case *graph.Kind: + pgTargets = append(pgTargets, new(int16)) + + case *graph.Kinds: + pgTargets = append(pgTargets, &[]int16{}) + + default: + pgTargets = append(pgTargets, target) + } + } + + if err := s.rows.Scan(pgTargets...); err != nil { + return err + } + + for idx, pgTarget := range pgTargets { + switch typedPGTarget := pgTarget.(type) { + case *pathComposite: + if err := typedPGTarget.ToPath(s.kindMapper, targets[idx].(*graph.Path)); err != nil { + return err + } + + case *edgeComposite: + if err := typedPGTarget.ToRelationship(s.kindMapper, targets[idx].(*graph.Relationship)); err != nil { + return err + } + + case *nodeComposite: + if err := typedPGTarget.ToNode(s.kindMapper, targets[idx].(*graph.Node)); err != nil { + return err + } + + case *int16: + if kindPtr, isKindType := targets[idx].(*graph.Kind); isKindType { + if kind, hasKind := s.kindMapper.MapKindID(*typedPGTarget); !hasKind { + return fmt.Errorf("unable to map kind ID %d", *typedPGTarget) + } else { + *kindPtr = kind + } + } + + case *[]int16: + if kindsPtr, isKindsType := targets[idx].(*graph.Kinds); isKindsType { + if kinds, 
missingKindIDs := s.kindMapper.MapKindIDs(*typedPGTarget...); len(missingKindIDs) > 0 { + return fmt.Errorf("unable to map kind IDs %+v", missingKindIDs) + } else { + *kindsPtr = kinds + } + } + } + } + + return nil +} + +func (s *queryResult) Error() error { + return s.rows.Err() +} + +func (s *queryResult) Close() { + s.rows.Close() +} diff --git a/packages/go/dawgs/drivers/pg/statements.go b/packages/go/dawgs/drivers/pg/statements.go new file mode 100644 index 0000000000..a18a00a9a8 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/statements.go @@ -0,0 +1,56 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +const ( + fetchNodeStatement = `select kinds, properties from node where node.id = $1;` + fetchNodeSliceStatement = `select id, kinds, properties from node where node.id = any($1);` + createNodeStatement = `insert into node (graph_id, kind_ids, properties) values ($1, $2, $3) returning id;` + createNodeWithoutIDBatchStatement = `insert into node (graph_id, kind_ids, properties) select $1, unnest($2::text[])::int2[], unnest($3::jsonb[])` + createNodeWithIDBatchStatement = `insert into node (graph_id, id, kind_ids, properties) select $1, unnest($2::int4[]), unnest($3::text[])::int2[], unnest($4::jsonb[])` + deleteNodeStatement = `delete from node where node.id = $1` + deleteNodeWithIDStatement = `delete from node where node.id = any($1)` + upsertNodeStatement = `insert into node (graph_id, )` // FIXME(review): statement is incomplete (unfinished column list, no values clause) and must not be executed as-is + + nodePropertySetOnlyStatement = `update node set kind_ids = $1, properties = properties || $2::jsonb where node.id = $3` + nodePropertyDeleteOnlyStatement = `update node set kind_ids = $1, properties = properties - $2::text[] where node.id = $3` + nodePropertySetAndDeleteStatement = `update node set kind_ids = $1, properties = properties || $2::jsonb - $3::text[] where node.id = $4` + + fetchEdgeStatement = `select start_id, end_id, kind, properties from relationships where relationships.id = $1;` + fetchEdgeSliceStatement = `select id, start_id, end_id, kind, properties from relationships where relationships.id = any($1);` + createEdgeStatement = `insert into edge (graph_id, start_id, end_id, kind_id, properties) values ($1, $2, $3, $4, $5) returning id;` + createEdgeBatchStatement = `merge into edge as e using (select $1::int4 as gid, unnest($2::int4[]) as sid, unnest($3::int4[]) as eid, unnest($4::int2[]) as kid, unnest($5::jsonb[]) as p) as ei on e.start_id = ei.sid and e.end_id = ei.eid and e.kind_id = ei.kid when matched then update set properties = e.properties || ei.p when not matched then insert (graph_id, 
start_id, end_id, kind_id, properties) values (ei.gid, ei.sid, ei.eid, ei.kid, ei.p);` + deleteEdgeStatement = `delete from edge as e where e.id = $1` + deleteEdgeWithIDStatement = `delete from edge as e where e.id = any($1)` + + edgePropertySetOnlyStatement = `update edge set properties = properties || $1::jsonb where edge.id = $2` + edgePropertyDeleteOnlyStatement = `update edge set properties = properties - $1::text[] where edge.id = $2` + edgePropertySetAndDeleteStatement = `update edge set properties = properties || $1::jsonb - $2::text[] where edge.id = $3` + + createNodesAndEdgeStatement = `with start_node as (insert into node (kinds, properties) values ($1, $2) returning id), +end_node as (insert into node (kinds, properties) values ($3, $4) returning id) + +insert into relationships (start_id, end_id, kind, properties) values((select id from start_node), (select id from end_node), $5, $6);` + + createStartNodeAndEdgeStatement = `with start_node as (insert into node (kinds, properties) values ($1, $2) returning id) + +insert into relationships (start_id, end_id, kind, properties) values((select id from start_node), $3, $4, $5);` + + createEndNodeAndEdgeStatement = `with end_node as (insert into node (kinds, properties) values ($1, $2) returning id) + +insert into relationships (start_id, end_id, kind, properties) values($3, (select id from end_node), $4, $5);` +) diff --git a/packages/go/dawgs/drivers/pg/tooling.go b/packages/go/dawgs/drivers/pg/tooling.go new file mode 100644 index 0000000000..12171b7ce4 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/tooling.go @@ -0,0 +1,142 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "github.com/specterops/bloodhound/dawgs/drivers" + "github.com/specterops/bloodhound/log" + "regexp" + "sync" +) + +type IterationOptions interface { + Once() +} + +type QueryHookOptions interface { + Trace() IterationOptions +} + +type QueryHook interface { + OnStatementMatch(statement string) QueryHookOptions + OnStatementRegex(re *regexp.Regexp) QueryHookOptions +} + +type actionType int + +const ( + actionTrace actionType = iota +) + +type queryHook struct { + statementMatch *string + statementRegex *regexp.Regexp + action actionType + actionIterations int +} + +func (s *queryHook) Execute(query string, arguments ...any) { + switch s.action { + case actionTrace: + log.Infof("Here") + } +} + +func (s *queryHook) Catches(query string, arguments ...any) bool { + if s.statementMatch != nil { + if query == *s.statementMatch { + return true + } + } + + if s.statementRegex != nil { + if s.statementRegex.MatchString(query) { + return true + } + } + + return false +} + +func (s *queryHook) Once() { + s.actionIterations = 1 +} + +func (s *queryHook) Times(actionIterations int) { + s.actionIterations = actionIterations +} + +func (s *queryHook) Trace() IterationOptions { + s.action = actionTrace + return s +} + +func (s *queryHook) OnStatementMatch(statement string) QueryHookOptions { + s.statementMatch = &statement + return s +} + +func (s *queryHook) OnStatementRegex(re *regexp.Regexp) QueryHookOptions { + s.statementRegex = re + return s +} + +type QueryPathInspector interface { + 
Hook() QueryHook +} + +type queryPathInspector struct { + hooks []*queryHook + lock *sync.RWMutex +} + +func (s *queryPathInspector) Inspect(query string, arguments ...any) { + if !drivers.IsQueryAnalysisEnabled() { + return + } + + s.lock.RLock() + defer s.lock.RUnlock() + + for _, hook := range s.hooks { + if hook.Catches(query, arguments) { + hook.Execute(query, arguments) + } + } +} + +func (s *queryPathInspector) Hook() QueryHook { + s.lock.Lock() + defer s.lock.Unlock() + + hook := &queryHook{} + s.hooks = append(s.hooks, hook) + + return hook +} + +var inspectorInst = &queryPathInspector{ + lock: &sync.RWMutex{}, +} + +func inspector() *queryPathInspector { + return inspectorInst +} + +func Inspector() QueryPathInspector { + return inspectorInst +} diff --git a/packages/go/dawgs/drivers/pg/transaction.go b/packages/go/dawgs/drivers/pg/transaction.go new file mode 100644 index 0000000000..2683b40f73 --- /dev/null +++ b/packages/go/dawgs/drivers/pg/transaction.go @@ -0,0 +1,321 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "bytes" + "context" + "fmt" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/specterops/bloodhound/cypher/backend/pgsql" + "github.com/specterops/bloodhound/cypher/frontend" + "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/dawgs/query" + "github.com/specterops/bloodhound/dawgs/util/size" +) + +type driver interface { + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) + Query(ctx context.Context, sql string, arguments ...any) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, arguments ...any) pgx.Row +} + +type inspectingDriver struct { + upstreamDriver driver +} + +func (s inspectingDriver) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) { + inspector().Inspect(sql, arguments...) + return s.upstreamDriver.Exec(ctx, sql, arguments...) +} + +func (s inspectingDriver) Query(ctx context.Context, sql string, arguments ...any) (pgx.Rows, error) { + inspector().Inspect(sql, arguments...) + return s.upstreamDriver.Query(ctx, sql, arguments...) +} + +func (s inspectingDriver) QueryRow(ctx context.Context, sql string, arguments ...any) pgx.Row { + inspector().Inspect(sql, arguments...) + return s.upstreamDriver.QueryRow(ctx, sql, arguments...) 
+} + +func newInspectingDriver(upstreamDriver driver) driver { + return &inspectingDriver{ + upstreamDriver: upstreamDriver, + } +} + +type transaction struct { + schemaManager *SchemaManager + queryExecMode pgx.QueryExecMode + queryResultsFormat pgx.QueryResultFormats + ctx context.Context + conn *pgxpool.Conn + tx pgx.Tx + targetSchema graph.Graph + targetSchemaSet bool +} + +func newTransaction(ctx context.Context, conn *pgxpool.Conn, schemaManager *SchemaManager, cfg *Config) (*transaction, error) { + if pgxTx, err := conn.BeginTx(ctx, cfg.Options); err != nil { + return nil, err + } else { + return &transaction{ + schemaManager: schemaManager, + queryExecMode: cfg.QueryExecMode, + queryResultsFormat: cfg.QueryResultFormats, + ctx: ctx, + conn: conn, + tx: pgxTx, + targetSchemaSet: false, + }, nil + } +} + +func (s *transaction) driver() driver { + if s.tx != nil { + return inspectingDriver{ + upstreamDriver: s.tx, + } + } + + return inspectingDriver{ + upstreamDriver: s.conn, + } +} + +func (s *transaction) TraversalMemoryLimit() size.Size { + return size.Gibibyte +} + +func (s *transaction) WithGraph(schema graph.Graph) graph.Transaction { + s.targetSchema = schema + s.targetSchemaSet = true + + return s +} + +func (s *transaction) Close() { + if s.tx != nil { + s.tx.Rollback(s.ctx) + s.tx = nil + } +} + +func (s *transaction) getTargetGraph() (model.Graph, error) { + if !s.targetSchemaSet { + // Look for a default graph target + if defaultGraph, hasDefaultGraph := s.schemaManager.DefaultGraph(); !hasDefaultGraph { + return model.Graph{}, fmt.Errorf("driver operation requires a graph target to be set") + } else { + return defaultGraph, nil + } + } + + return s.schemaManager.AssertGraph(s, s.targetSchema) +} + +func (s *transaction) CreateNode(properties *graph.Properties, kinds ...graph.Kind) (*graph.Node, error) { + if graphTarget, err := s.getTargetGraph(); err != nil { + return nil, err + } else if kindIDSlice, err := s.schemaManager.AssertKinds(s, kinds); 
 err != nil { + return nil, err + } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(properties); err != nil { + return nil, err + } else { + var ( + nodeID int32 + result = s.queryRow(createNodeStatement, s.queryExecMode, graphTarget.ID, kindIDSlice, propertiesJSONB) + ) + + if err := result.Scan(&nodeID); err != nil { + return nil, err + } + + return graph.NewNode(graph.ID(nodeID), properties, kinds...), nil + } +} + +func (s *transaction) UpdateNode(node *graph.Node) error { + var ( + properties = node.Properties + updateStatements []graph.Criteria + ) + + if addedKinds := node.AddedKinds; len(addedKinds) > 0 { + updateStatements = append(updateStatements, query.AddKinds(query.Node(), addedKinds)) + } + + if deletedKinds := node.DeletedKinds; len(deletedKinds) > 0 { + updateStatements = append(updateStatements, query.DeleteKinds(query.Node(), deletedKinds)) + } + + if modifiedProperties := properties.ModifiedProperties(); len(modifiedProperties) > 0 { + updateStatements = append(updateStatements, query.SetProperties(query.Node(), modifiedProperties)) + } + + if deletedProperties := properties.DeletedProperties(); len(deletedProperties) > 0 { + updateStatements = append(updateStatements, query.DeleteProperties(query.Node(), deletedProperties...)) + } + + return s.Nodes().Filter(query.Equals(query.NodeID(), node.ID)).Query(func(results graph.Result) error { + // We don't need to exhaust the result set as the deferred close will discard it for us + return results.Error() + }, updateStatements...) 
+} + +func (s *transaction) Nodes() graph.NodeQuery { + return &nodeQuery{ + liveQuery: newLiveQuery(s.ctx, s, s.schemaManager), + } +} + +func (s *transaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { + if graphTarget, err := s.getTargetGraph(); err != nil { + return nil, err + } else if kindIDSlice, err := s.schemaManager.AssertKinds(s, graph.Kinds{kind}); err != nil { + return nil, err + } else if propertiesJSONB, err := pgsql.PropertiesToJSONB(properties); err != nil { + return nil, err + } else { + var ( + edgeID int32 + result = s.queryRow(createEdgeStatement, s.queryExecMode, graphTarget.ID, startNodeID, endNodeID, kindIDSlice[0], propertiesJSONB) + ) + + if err := result.Scan(&edgeID); err != nil { + return nil, err + } + + return graph.NewRelationship(graph.ID(edgeID), startNodeID, endNodeID, properties, kind), nil + } +} + +func (s *transaction) UpdateRelationship(relationship *graph.Relationship) error { + var ( + modifiedProperties = relationship.Properties.ModifiedProperties() + deletedProperties = relationship.Properties.DeletedProperties() + numModifiedProperties = len(modifiedProperties) + numDeletedProperties = len(deletedProperties) + + statement string + arguments []any + ) + + if numModifiedProperties > 0 { + if jsonbArgument, err := pgsql.ValueToJSONB(modifiedProperties); err != nil { + return err + } else { + arguments = append(arguments, jsonbArgument) + } + + if numDeletedProperties > 0 { + if textArrayArgument, err := pgsql.StringSliceToTextArray(deletedProperties); err != nil { + return err + } else { + arguments = append(arguments, textArrayArgument) + } + + statement = edgePropertySetAndDeleteStatement + } else { + statement = edgePropertySetOnlyStatement + } + } else if numDeletedProperties > 0 { + if textArrayArgument, err := pgsql.StringSliceToTextArray(deletedProperties); err != nil { + return err + } else { + arguments = append(arguments, 
textArrayArgument) + } + + statement = edgePropertyDeleteOnlyStatement + } + + _, err := s.driver().Exec(s.ctx, statement, append(arguments, relationship.ID)...) + return err +} + +func (s *transaction) Relationships() graph.RelationshipQuery { + return &relationshipQuery{ + liveQuery: newLiveQuery(s.ctx, s, s.schemaManager), + } +} + +func (s *transaction) queryRow(query string, parameters ...any) pgx.Row { + queryArgs := []any{s.queryExecMode, s.queryResultsFormat} + queryArgs = append(queryArgs, parameters...) + + return s.driver().QueryRow(s.ctx, query, queryArgs...) +} + +func (s *transaction) query(query string, parameters map[string]any) (pgx.Rows, error) { + queryArgs := []any{s.queryExecMode, s.queryResultsFormat} + + if parameters != nil || len(parameters) > 0 { + queryArgs = append(queryArgs, pgx.NamedArgs(parameters)) + } + + return s.driver().Query(s.ctx, query, queryArgs...) +} + +func (s *transaction) Query(query string, parameters map[string]any) graph.Result { + if parsedQuery, err := frontend.ParseCypher(frontend.NewContext(), query); err != nil { + return graph.NewErrorResult(err) + } else if translatedParams, err := pgsql.Translate(parsedQuery, s.schemaManager); err != nil { + return graph.NewErrorResult(err) + } else { + var ( + buffer = &bytes.Buffer{} + emitter = pgsql.NewEmitter(false, s.schemaManager) + ) + + for key, value := range parameters { + if _, hasKey := translatedParams[key]; hasKey { + return graph.NewErrorResult(fmt.Errorf("Query specifies a parameter value that is overwritten by translation: %s", key)) + } + + translatedParams[key] = value + } + + if err := emitter.Write(parsedQuery, buffer); err != nil { + return graph.NewErrorResult(err) + } + + return s.Raw(buffer.String(), parameters) + } +} + +func (s *transaction) Raw(query string, parameters map[string]any) graph.Result { + if rows, err := s.query(query, parameters); err != nil { + return graph.NewErrorResult(err) + } else { + return &queryResult{ + rows: rows, + 
kindMapper: s.schemaManager, + } + } +} + +func (s *transaction) Commit() error { + if s.tx != nil { + return s.tx.Commit(s.ctx) + } + + return nil +} diff --git a/packages/go/dawgs/drivers/pg/types.go b/packages/go/dawgs/drivers/pg/types.go new file mode 100644 index 0000000000..c0b0bfa8fe --- /dev/null +++ b/packages/go/dawgs/drivers/pg/types.go @@ -0,0 +1,195 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package pg + +import ( + "fmt" + "github.com/specterops/bloodhound/dawgs/graph" +) + +type edgeComposite struct { + ID int32 + StartID int32 + EndID int32 + KindID int16 + Properties map[string]any +} + +func castSlice[T any](raw any) ([]T, error) { + if rawSlice, typeOK := raw.([]any); !typeOK { + return nil, fmt.Errorf("expected raw type []any but received %T", raw) + } else { + sliceCopy := make([]T, len(rawSlice)) + + for idx, rawValue := range rawSlice { + if typedValue, typeOK := rawValue.(T); !typeOK { + var empty T + return nil, fmt.Errorf("expected type %T but received %T", empty, rawValue) + } else { + sliceCopy[idx] = typedValue + } + } + + return sliceCopy, nil + } +} + +func castMapValueAsSliceOf[T any](compositeMap map[string]any, key string) ([]T, error) { + if src, hasKey := compositeMap[key]; !hasKey { + return nil, fmt.Errorf("composite map does not contain expected key %s", key) + } else { + return castSlice[T](src) + } +} + +func castAndAssignMapValue[T 
any](compositeMap map[string]any, key string, dst *T) error { + if src, hasKey := compositeMap[key]; !hasKey { + return fmt.Errorf("composite map does not contain expected key %s", key) + } else if typed, typeOK := src.(T); !typeOK { + var empty T + return fmt.Errorf("expected type %T but received %T", empty, src) + } else { + *dst = typed + } + + return nil +} + +func (s *edgeComposite) TryMap(compositeMap map[string]any) bool { + return s.FromMap(compositeMap) == nil +} + +func (s *edgeComposite) FromMap(compositeMap map[string]any) error { + if err := castAndAssignMapValue(compositeMap, "id", &s.ID); err != nil { + return err + } + + if err := castAndAssignMapValue(compositeMap, "start_id", &s.StartID); err != nil { + return err + } + + if err := castAndAssignMapValue(compositeMap, "end_id", &s.EndID); err != nil { + return err + } + + if err := castAndAssignMapValue(compositeMap, "kind_id", &s.KindID); err != nil { + return err + } + + if err := castAndAssignMapValue(compositeMap, "properties", &s.Properties); err != nil { + return err + } + + return nil +} + +func (s *edgeComposite) ToRelationship(kindMapper KindMapper, relationship *graph.Relationship) error { + if kinds, missingIDs := kindMapper.MapKindIDs(s.KindID); len(missingIDs) > 0 { + return fmt.Errorf("edge references the following unknown kind IDs: %v", missingIDs) + } else { + relationship.Kind = kinds[0] + } + + relationship.ID = graph.ID(s.ID) + relationship.StartID = graph.ID(s.StartID) + relationship.EndID = graph.ID(s.EndID) + relationship.Properties = graph.AsProperties(s.Properties) + + return nil +} + +type nodeComposite struct { + ID int32 + KindIDs []int16 + Properties map[string]any +} + +func (s *nodeComposite) TryMap(compositeMap map[string]any) bool { + return s.FromMap(compositeMap) == nil +} + +func (s *nodeComposite) FromMap(compositeMap map[string]any) error { + if err := castAndAssignMapValue(compositeMap, "id", &s.ID); err != nil { + return err + } + + if kindIDs, err := 
castMapValueAsSliceOf[int16](compositeMap, "kind_ids"); err != nil { + return err + } else { + s.KindIDs = kindIDs + } + + if err := castAndAssignMapValue(compositeMap, "properties", &s.Properties); err != nil { + return err + } + + return nil +} + +func (s *nodeComposite) ToNode(kindMapper KindMapper, node *graph.Node) error { + if kinds, missingIDs := kindMapper.MapKindIDs(s.KindIDs...); len(missingIDs) > 0 { + return fmt.Errorf("node references the following unknown kind IDs: %v", missingIDs) + } else { + node.Kinds = kinds + } + + node.ID = graph.ID(s.ID) + node.Properties = graph.AsProperties(s.Properties) + + return nil +} + +type pathComposite struct { + Nodes []nodeComposite + Edges []edgeComposite +} + +func (s *pathComposite) TryMap(compositeMap map[string]any) bool { + return s.FromMap(compositeMap) == nil +} + +func (s *pathComposite) FromMap(compositeMap map[string]any) error { + return nil +} + +func (s *pathComposite) ToPath(kindMapper KindMapper, path *graph.Path) error { + path.Nodes = make([]*graph.Node, len(s.Nodes)) + + for idx, pgNode := range s.Nodes { + dawgsNode := &graph.Node{} + + if err := pgNode.ToNode(kindMapper, dawgsNode); err != nil { + return err + } + + path.Nodes[idx] = dawgsNode + } + + path.Edges = make([]*graph.Relationship, len(s.Edges)) + + for idx, pgEdge := range s.Edges { + dawgsRelationship := &graph.Relationship{} + + if err := pgEdge.ToRelationship(kindMapper, dawgsRelationship); err != nil { + return err + } + + path.Edges[idx] = dawgsRelationship + } + + return nil +} diff --git a/packages/go/dawgs/drivers/neo4j/analysis.go b/packages/go/dawgs/drivers/tooling.go similarity index 97% rename from packages/go/dawgs/drivers/neo4j/analysis.go rename to packages/go/dawgs/drivers/tooling.go index 109434e9ec..d20601fb51 100644 --- a/packages/go/dawgs/drivers/neo4j/analysis.go +++ b/packages/go/dawgs/drivers/tooling.go @@ -1,20 +1,20 @@ // Copyright 2023 Specter Ops, Inc. 
-// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 -package neo4j +package drivers import "sync/atomic" diff --git a/packages/go/dawgs/go.mod b/packages/go/dawgs/go.mod index 0c512fe1d6..3629beb259 100644 --- a/packages/go/dawgs/go.mod +++ b/packages/go/dawgs/go.mod @@ -1,27 +1,29 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 module github.com/specterops/bloodhound/dawgs -go 1.20 +go 1.21 require ( github.com/RoaringBitmap/roaring v1.3.0 github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc github.com/gammazero/deque v0.2.1 + github.com/jackc/pgtype v1.14.0 + github.com/jackc/pgx/v5 v5.5.1 github.com/neo4j/neo4j-go-driver/v5 v5.9.0 github.com/specterops/bloodhound/cypher v0.0.0-00010101000000-000000000000 github.com/specterops/bloodhound/log v0.0.0-00010101000000-000000000000 @@ -34,14 +36,23 @@ require ( github.com/bits-and-blooms/bitset v1.8.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect + github.com/jackc/pgconn v1.14.0 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mschoch/smat v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rs/zerolog v1.29.1 // indirect + golang.org/x/crypto v0.10.0 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect + golang.org/x/sync v0.3.0 // indirect golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/packages/go/dawgs/go.sum b/packages/go/dawgs/go.sum index 74a4159f34..e85d2be62b 100644 --- a/packages/go/dawgs/go.sum +++ b/packages/go/dawgs/go.sum @@ -1,7 +1,6 @@ github.com/RoaringBitmap/roaring v1.3.0 h1:aQmu9zQxDU0uhwR8SXOH/OrqEf+X8A0LQmwW3JX8Lcg= github.com/RoaringBitmap/roaring v1.3.0/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/antlr4-go/antlr/v4 v4.13.0 
h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= -github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc h1:Keo7wQ7UODUaHcEi7ltENhbAK2VgZjfat6mLy03tQzo= github.com/axiomhq/hyperloglog v0.0.0-20230201085229-3ddf4bad03dc/go.mod h1:k08r+Yj1PRAmuayFiRK6MYuR5Ve4IuZtTfxErMIh0+c= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= @@ -18,9 +17,23 @@ github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140/go.mod h1:c9O8+fp github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0= +github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -40,19 +53,22 @@ github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= go.uber.org/mock v0.2.0 h1:TaP3xedm7JaAgScZO7tlvlKrqT0p7I6OsdGB5YNSMDU= go.uber.org/mock v0.2.0/go.mod h1:J0y0rp9L3xiff1+ZBfKxlC1fz2+aO16tw0tsDOixfuM= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= -golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/packages/go/dawgs/graph/graph.go b/packages/go/dawgs/graph/graph.go index 7311579d3f..1a6446bcbf 100644 --- a/packages/go/dawgs/graph/graph.go +++ b/packages/go/dawgs/graph/graph.go @@ -21,7 +21,9 @@ package graph import ( "context" "errors" + "slices" "strconv" + "strings" "time" "github.com/specterops/bloodhound/dawgs/util/size" @@ -92,6 +94,19 @@ func (s Direction) PickReverse(relationship *Relationship) (ID, error) { return s.PickReverseID(relationship.StartID, relationship.EndID) } +func (s Direction) String() string { + switch s { + case DirectionInbound: + return "inbound" + case DirectionOutbound: + return "outbound" + case DirectionBoth: + return "both" + default: + return "invalid" + } +} + // ID is a 32-bit database Entity identifier type. Negative ID value associations in DAWGS drivers are not recommended // and should not be considered during driver implementation. 
type ID uint32 @@ -175,6 +190,22 @@ type NodeUpdate struct { IdentityProperties []string } +func (s NodeUpdate) Key() (string, error) { + key := strings.Builder{} + + slices.Sort(s.IdentityProperties) + + for _, identityProperty := range s.IdentityProperties { + if propertyValue, err := s.Node.Properties.Get(identityProperty).String(); err != nil { + return "", err + } else { + key.WriteString(propertyValue) + } + } + + return key.String(), nil +} + type RelationshipUpdate struct { Relationship *Relationship IdentityProperties []string @@ -186,6 +217,46 @@ type RelationshipUpdate struct { EndIdentityProperties []string } +func (s RelationshipUpdate) Key() (string, error) { + var ( + key = strings.Builder{} + startNodeUpdate = NodeUpdate{ + Node: s.Start, + IdentityKind: s.StartIdentityKind, + IdentityProperties: s.StartIdentityProperties, + } + + endNodeUpdate = NodeUpdate{ + Node: s.End, + IdentityKind: s.EndIdentityKind, + IdentityProperties: s.EndIdentityProperties, + } + ) + + key.WriteString(s.Relationship.Kind.String()) + + if startKey, err := startNodeUpdate.Key(); err != nil { + return "", err + } else if endKey, err := endNodeUpdate.Key(); err != nil { + return "", err + } else { + key.WriteString(startKey) + + slices.Sort(s.IdentityProperties) + + for _, identityProperty := range s.IdentityProperties { + if propertyValue, err := s.Relationship.Properties.Get(identityProperty).String(); err != nil { + return "", err + } else { + key.WriteString(propertyValue) + } + } + + key.WriteString(endKey) + return key.String(), nil + } +} + func (s RelationshipUpdate) IdentityPropertiesMap() map[string]any { identityPropertiesMap := make(map[string]any, len(s.IdentityProperties)) @@ -217,8 +288,12 @@ func (s RelationshipUpdate) EndIdentityPropertiesMap() map[string]any { } type Batch interface { + // WithGraph scopes the transaction to a specific graph. 
If the driver for the transaction does not support + // multiple graphs the resulting transaction will target the default graph instead and this call becomes a no-op. + WithGraph(graphSchema Graph) Batch + // CreateNode creates a new Node in the database and returns the creation as a NodeResult. - CreateNode(properties *Properties, kinds ...Kind) error + CreateNode(node *Node) error // DeleteNode deletes a node by the given ID. DeleteNode(id ID) error @@ -234,12 +309,13 @@ type Batch interface { // exist, created. UpdateNodeBy(update NodeUpdate) error - // CreateRelationship creates a new Relationship from the start Node to the end Node with the given Kind and - // Properties and returns the creation as a RelationshipResult. - CreateRelationship(startNode, endNode *Node, kind Kind, properties *Properties) error + // TODO: Existing batch logic expects this to perform an upsert on conficts with (start_id, end_id, kind). This is incorrect and should be refactored + CreateRelationship(relationship *Relationship) error // CreateRelationshipByIDs creates a new Relationship from the start Node to the end Node with the given Kind and // Properties and returns the creation as a RelationshipResult. + // + // Deprecated: Use CreateRelationship CreateRelationshipByIDs(startNodeID, endNodeID ID, kind Kind, properties *Properties) error // DeleteRelationship deletes a relationship by the given ID. @@ -257,6 +333,10 @@ type Batch interface { // Transaction is an interface that contains all operations that may be executed against a DAWGS driver. DAWGS drivers are // expected to support all Transaction operations in-transaction. type Transaction interface { + // WithGraph scopes the transaction to a specific graph. If the driver for the transaction does not support + // multiple graphs the resulting transaction will target the default graph instead and this call becomes a no-op. 
+ WithGraph(graphSchema Graph) Transaction + // CreateNode creates a new Node in the database and returns the creation as a NodeResult. CreateNode(properties *Properties, kinds ...Kind) (*Node, error) @@ -264,17 +344,9 @@ type Transaction interface { // entries in the database. Use CreateNode first to create a new Node. UpdateNode(node *Node) error - // UpdateNodeBy updates a Node by attempting to write a valid merge statement for the criteria in the given - // NodeUpdate struct. - UpdateNodeBy(update NodeUpdate) error - // Nodes creates a new NodeQuery and returns it. Nodes() NodeQuery - // CreateRelationship creates a new Relationship from the start Node to the end Node with the given Kind and - // Properties and returns the creation as a RelationshipResult. - CreateRelationship(startNode, endNode *Node, kind Kind, properties *Properties) (*Relationship, error) - // CreateRelationshipByIDs creates a new Relationship from the start Node to the end Node with the given Kind and // Properties and returns the creation as a RelationshipResult. CreateRelationshipByIDs(startNodeID, endNodeID ID, kind Kind, properties *Properties) (*Relationship, error) @@ -284,15 +356,14 @@ type Transaction interface { // Relationship. UpdateRelationship(relationship *Relationship) error - // UpdateRelationshipBy updates a Relationship by attempting to write a valid merge statement for the criteria in - // the given RelationshipUpdate struct. - UpdateRelationshipBy(update RelationshipUpdate) error - // Relationships creates a new RelationshipQuery and returns it. Relationships() RelationshipQuery - // Run allows a user to pass statements directly to the database. - Run(query string, parameters map[string]any) Result + // Raw allows a user to pass raw queries directly to the database without translation. + Raw(query string, parameters map[string]any) Result + + // Query allows a user to execute a given cypher query that will be translated to the target database. 
+ Query(query string, parameters map[string]any) Result // Commit calls to commit this transaction right away. Commit() error @@ -311,7 +382,8 @@ type BatchDelegate func(batch Batch) error // TransactionConfig is a generic configuration that may apply to all supported databases. type TransactionConfig struct { - Timeout time.Duration + Timeout time.Duration + DriverConfig any } // TransactionOption is a function that represents a configuration setting for the underlying database transaction. @@ -339,16 +411,13 @@ type Database interface { // transaction. BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error - // AssertSchema will apply the given schema model to the underlying database. - AssertSchema(ctx context.Context, schema *Schema) error - - // FetchSchema will pull the schema of the underlying database and marshal it into the DAWGS schema model. - FetchSchema(ctx context.Context) (*Schema, error) + // AssertSchema will apply the given schema to the underlying database. + AssertSchema(ctx context.Context, dbSchema Schema) error // Run allows a user to pass statements directly to the database. Since results may rely on a transactional context // only an error is returned from this function Run(ctx context.Context, query string, parameters map[string]any) error // Close closes the database context and releases any pooled resources held by the instance. 
- Close() error + Close(ctx context.Context) error } diff --git a/packages/go/dawgs/graph/kind.go b/packages/go/dawgs/graph/kind.go index 85aedf1973..c4c8a08026 100644 --- a/packages/go/dawgs/graph/kind.go +++ b/packages/go/dawgs/graph/kind.go @@ -17,7 +17,6 @@ package graph import ( - "context" "sync" "unsafe" @@ -178,10 +177,3 @@ func (s stringKind) Is(other ...Kind) bool { return false } - -type KindMapper interface { - GetKind(kindIntKey int32) (Kind, error) - GetKinds(kindIntKeys []int32) (Kinds, error) - GetKey(ctx context.Context, kindStr string) (int32, error) - GetKeys(ctx context.Context, kindStr []string) ([]int32, error) -} diff --git a/packages/go/dawgs/graph/mapper.go b/packages/go/dawgs/graph/mapper.go new file mode 100644 index 0000000000..acb87a4a5c --- /dev/null +++ b/packages/go/dawgs/graph/mapper.go @@ -0,0 +1,483 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package graph + +import ( + "fmt" + "strconv" + "time" +) + +type numeric interface { + uint | uint8 | uint16 | uint32 | uint64 | int | int8 | int16 | int32 | int64 | float32 | float64 | ID +} + +func AsNumeric[T numeric](rawValue any) (T, error) { + var empty T + + switch typedValue := rawValue.(type) { + case uint: + return T(typedValue), nil + case uint8: + return T(typedValue), nil + case uint16: + return T(typedValue), nil + case uint32: + return T(typedValue), nil + case uint64: + return T(typedValue), nil + case int: + return T(typedValue), nil + case int8: + return T(typedValue), nil + case int16: + return T(typedValue), nil + case int32: + return T(typedValue), nil + case int64: + return T(typedValue), nil + case float32: + return T(typedValue), nil + case float64: + return T(typedValue), nil + case string: + if parsedInt, err := strconv.ParseInt(typedValue, 10, 64); err != nil { + if parsedFloat, err := strconv.ParseFloat(typedValue, 64); err != nil { + return empty, fmt.Errorf("unable to parse numeric value from raw value %s", typedValue) + } else { + return T(parsedFloat), nil + } + } else { + return T(parsedInt), nil + } + + default: + return empty, fmt.Errorf("unable to convert raw value %T as numeric", rawValue) + } +} + +func castNumericSlice[R numeric, T any](src []T) ([]R, error) { + dst := make([]R, len(src)) + + for idx, srcValue := range src { + if numericValue, err := AsNumeric[R](srcValue); err != nil { + return nil, err + } else { + dst[idx] = numericValue + } + } + + return dst, nil +} + +func AsNumericSlice[T numeric](rawValue any) ([]T, error) { + var numericSlice []T + + switch typedValue := rawValue.(type) { + case []any: + return castNumericSlice[T](typedValue) + case []uint: + return castNumericSlice[T](typedValue) + case []uint8: + return castNumericSlice[T](typedValue) + case []uint16: + return castNumericSlice[T](typedValue) + case []uint32: + return castNumericSlice[T](typedValue) + 
case []uint64: + return castNumericSlice[T](typedValue) + case []int: + return castNumericSlice[T](typedValue) + case []int8: + return castNumericSlice[T](typedValue) + case []int16: + return castNumericSlice[T](typedValue) + case []int32: + return castNumericSlice[T](typedValue) + case []int64: + return castNumericSlice[T](typedValue) + case []float32: + return castNumericSlice[T](typedValue) + case []float64: + return castNumericSlice[T](typedValue) + default: + return nil, fmt.Errorf("unable to convert raw value %T as a numeric slice", rawValue) + } + + return numericSlice, nil +} + +func AsKinds(rawValue any) (Kinds, error) { + if stringValues, err := SliceOf[string](rawValue); err != nil { + return nil, err + } else { + return StringsToKinds(stringValues), nil + } +} + +func AsTime(value any) (time.Time, error) { + switch typedValue := value.(type) { + case string: + if parsedTime, err := time.Parse(time.RFC3339Nano, typedValue); err != nil { + return time.Time{}, err + } else { + return parsedTime, nil + } + + case float64: + return time.Unix(int64(typedValue), 0), nil + + case int64: + return time.Unix(typedValue, 0), nil + + case time.Time: + return typedValue, nil + + default: + return time.Time{}, fmt.Errorf("unexecpted type %T will not negotiate to time.Time", value) + } +} + +func defaultMapValue(rawValue, target any) (bool, error) { + switch typedTarget := target.(type) { + case *uint: + if value, err := AsNumeric[uint](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]uint: + if value, err := AsNumericSlice[uint](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *uint8: + if value, err := AsNumeric[uint8](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]uint8: + if value, err := AsNumericSlice[uint8](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *uint16: + if value, err := 
AsNumeric[uint16](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]uint16: + if value, err := AsNumericSlice[uint16](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *uint32: + if value, err := AsNumeric[uint32](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]uint32: + if value, err := AsNumericSlice[uint32](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *uint64: + if value, err := AsNumeric[uint64](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]uint64: + if value, err := AsNumericSlice[uint64](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *int: + if value, err := AsNumeric[int](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]int: + if value, err := AsNumericSlice[int](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *int8: + if value, err := AsNumeric[int8](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]int8: + if value, err := AsNumericSlice[int8](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *int16: + if value, err := AsNumeric[int16](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]int16: + if value, err := AsNumericSlice[int16](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *int32: + if value, err := AsNumeric[int32](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]int32: + if value, err := AsNumericSlice[int32](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *int64: + if value, err := AsNumeric[int64](rawValue); err != nil { + return false, err 
+ } else { + *typedTarget = value + } + + case *[]int64: + if value, err := AsNumericSlice[int64](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *ID: + if value, err := AsNumeric[ID](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]ID: + if value, err := AsNumericSlice[ID](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *float32: + if value, err := AsNumeric[float32](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]float32: + if value, err := AsNumericSlice[float32](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *float64: + if value, err := AsNumeric[float64](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *[]float64: + if value, err := AsNumericSlice[float64](rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + case *bool: + if value, typeOK := rawValue.(bool); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to bool", value) + } else { + *typedTarget = value + } + + case *Kind: + if strValue, typeOK := rawValue.(string); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to string", rawValue) + } else { + *typedTarget = StringKind(strValue) + } + + case *string: + if value, typeOK := rawValue.(string); !typeOK { + return false, fmt.Errorf("unexpected type %T will not negotiate to string", rawValue) + } else { + *typedTarget = value + } + + case *[]Kind: + if kindValues, err := AsKinds(rawValue); err != nil { + return false, err + } else { + *typedTarget = kindValues + } + + case *Kinds: + if kindValues, err := AsKinds(rawValue); err != nil { + return false, err + } else { + *typedTarget = kindValues + } + + case *[]string: + if stringValues, err := SliceOf[string](rawValue); err != nil { + return false, err + } else { + 
*typedTarget = stringValues + } + + case *time.Time: + if value, err := AsTime(rawValue); err != nil { + return false, err + } else { + *typedTarget = value + } + + default: + return false, nil + } + + return true, nil +} + +type MapFunc func(rawValue, target any) (bool, error) + +type valueMapper struct { + mapperFuncs []MapFunc + values []any + idx int +} + +func NewValueMapper(values []any, mappers ...MapFunc) ValueMapper { + return &valueMapper{ + mapperFuncs: append(mappers, defaultMapValue), + values: values, + idx: 0, + } +} + +func (s *valueMapper) Next() (any, error) { + if s.idx >= len(s.values) { + return nil, fmt.Errorf("attempting to get more values than returned - saw %d but wanted %d", len(s.values), s.idx+1) + } + + nextValue := s.values[s.idx] + s.idx++ + + return nextValue, nil +} + +func (s *valueMapper) Map(target any) error { + if rawValue, err := s.Next(); err != nil { + return err + } else { + for _, mapperFunc := range s.mapperFuncs { + if mapped, err := mapperFunc(rawValue, target); err != nil { + return err + } else if mapped { + return nil + } + } + } + + return fmt.Errorf("unsupported scan type %T", target) +} + +func SliceOf[T any](raw any) ([]T, error) { + if slice, typeOK := raw.([]any); !typeOK { + return nil, fmt.Errorf("expected []any slice but received %T", raw) + } else { + sliceCopy := make([]T, len(slice)) + + for idx, sliceValue := range slice { + if typedSliceValue, typeOK := sliceValue.(T); !typeOK { + var empty T + return nil, fmt.Errorf("expected type %T but received %T", empty, sliceValue) + } else { + sliceCopy[idx] = typedSliceValue + } + } + + return sliceCopy, nil + } +} + +func (s *valueMapper) MapOptions(targets ...any) (any, error) { + if rawValue, err := s.Next(); err != nil { + return nil, err + } else { + for _, target := range targets { + for _, mapperFunc := range s.mapperFuncs { + if mapped, _ := mapperFunc(rawValue, target); mapped { + return target, nil + } + } + } + + return nil, fmt.Errorf("no matching 
target given for type: %T", rawValue) + } +} + +func (s *valueMapper) Scan(targets ...any) error { + for idx, mapValue := range targets { + if err := s.Map(mapValue); err != nil { + return err + } else { + targets[idx] = mapValue + } + } + + return nil +} diff --git a/packages/go/dawgs/graph/mocks/graph.go b/packages/go/dawgs/graph/mocks/graph.go index 0dfb17b396..274802fb5f 100644 --- a/packages/go/dawgs/graph/mocks/graph.go +++ b/packages/go/dawgs/graph/mocks/graph.go @@ -334,36 +334,31 @@ func (mr *MockBatchMockRecorder) Commit() *gomock.Call { } // CreateNode mocks base method. -func (m *MockBatch) CreateNode(properties *graph.Properties, kinds ...graph.Kind) error { +func (m *MockBatch) CreateNode(node *graph.Node) error { m.ctrl.T.Helper() - varargs := []interface{}{properties} - for _, a := range kinds { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "CreateNode", varargs...) + ret := m.ctrl.Call(m, "CreateNode", node) ret0, _ := ret[0].(error) return ret0 } // CreateNode indicates an expected call of CreateNode. -func (mr *MockBatchMockRecorder) CreateNode(properties interface{}, kinds ...interface{}) *gomock.Call { +func (mr *MockBatchMockRecorder) CreateNode(node interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{properties}, kinds...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNode", reflect.TypeOf((*MockBatch)(nil).CreateNode), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNode", reflect.TypeOf((*MockBatch)(nil).CreateNode), node) } // CreateRelationship mocks base method. 
-func (m *MockBatch) CreateRelationship(startNode, endNode *graph.Node, kind graph.Kind, properties *graph.Properties) error { +func (m *MockBatch) CreateRelationship(relationship *graph.Relationship) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRelationship", startNode, endNode, kind, properties) + ret := m.ctrl.Call(m, "CreateRelationship", relationship) ret0, _ := ret[0].(error) return ret0 } // CreateRelationship indicates an expected call of CreateRelationship. -func (mr *MockBatchMockRecorder) CreateRelationship(startNode, endNode, kind, properties interface{}) *gomock.Call { +func (mr *MockBatchMockRecorder) CreateRelationship(relationship interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRelationship", reflect.TypeOf((*MockBatch)(nil).CreateRelationship), startNode, endNode, kind, properties) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRelationship", reflect.TypeOf((*MockBatch)(nil).CreateRelationship), relationship) } // CreateRelationshipByIDs mocks base method. @@ -464,6 +459,20 @@ func (mr *MockBatchMockRecorder) UpdateRelationshipBy(update interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRelationshipBy", reflect.TypeOf((*MockBatch)(nil).UpdateRelationshipBy), update) } +// WithGraph mocks base method. +func (m *MockBatch) WithGraph(graphSchema graph.Graph) graph.Batch { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithGraph", graphSchema) + ret0, _ := ret[0].(graph.Batch) + return ret0 +} + +// WithGraph indicates an expected call of WithGraph. +func (mr *MockBatchMockRecorder) WithGraph(graphSchema interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithGraph", reflect.TypeOf((*MockBatch)(nil).WithGraph), graphSchema) +} + // MockTransaction is a mock of Transaction interface. 
type MockTransaction struct { ctrl *gomock.Controller @@ -521,21 +530,6 @@ func (mr *MockTransactionMockRecorder) CreateNode(properties interface{}, kinds return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNode", reflect.TypeOf((*MockTransaction)(nil).CreateNode), varargs...) } -// CreateRelationship mocks base method. -func (m *MockTransaction) CreateRelationship(startNode, endNode *graph.Node, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRelationship", startNode, endNode, kind, properties) - ret0, _ := ret[0].(*graph.Relationship) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateRelationship indicates an expected call of CreateRelationship. -func (mr *MockTransactionMockRecorder) CreateRelationship(startNode, endNode, kind, properties interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRelationship", reflect.TypeOf((*MockTransaction)(nil).CreateRelationship), startNode, endNode, kind, properties) -} - // CreateRelationshipByIDs mocks base method. func (m *MockTransaction) CreateRelationshipByIDs(startNodeID, endNodeID graph.ID, kind graph.Kind, properties *graph.Properties) (*graph.Relationship, error) { m.ctrl.T.Helper() @@ -565,32 +559,46 @@ func (mr *MockTransactionMockRecorder) Nodes() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Nodes", reflect.TypeOf((*MockTransaction)(nil).Nodes)) } -// Relationships mocks base method. -func (m *MockTransaction) Relationships() graph.RelationshipQuery { +// Query mocks base method. 
+func (m *MockTransaction) Query(query string, parameters map[string]any) graph.Result { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Relationships") - ret0, _ := ret[0].(graph.RelationshipQuery) + ret := m.ctrl.Call(m, "Query", query, parameters) + ret0, _ := ret[0].(graph.Result) return ret0 } -// Relationships indicates an expected call of Relationships. -func (mr *MockTransactionMockRecorder) Relationships() *gomock.Call { +// Query indicates an expected call of Query. +func (mr *MockTransactionMockRecorder) Query(query, parameters interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relationships", reflect.TypeOf((*MockTransaction)(nil).Relationships)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), query, parameters) } -// Run mocks base method. -func (m *MockTransaction) Run(query string, parameters map[string]any) graph.Result { +// Raw mocks base method. +func (m *MockTransaction) Raw(query string, parameters map[string]any) graph.Result { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Run", query, parameters) + ret := m.ctrl.Call(m, "Raw", query, parameters) ret0, _ := ret[0].(graph.Result) return ret0 } -// Run indicates an expected call of Run. -func (mr *MockTransactionMockRecorder) Run(query, parameters interface{}) *gomock.Call { +// Raw indicates an expected call of Raw. +func (mr *MockTransactionMockRecorder) Raw(query, parameters interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockTransaction)(nil).Run), query, parameters) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Raw", reflect.TypeOf((*MockTransaction)(nil).Raw), query, parameters) +} + +// Relationships mocks base method. 
+func (m *MockTransaction) Relationships() graph.RelationshipQuery { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Relationships") + ret0, _ := ret[0].(graph.RelationshipQuery) + return ret0 +} + +// Relationships indicates an expected call of Relationships. +func (mr *MockTransactionMockRecorder) Relationships() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Relationships", reflect.TypeOf((*MockTransaction)(nil).Relationships)) } // TraversalMemoryLimit mocks base method. @@ -621,20 +629,6 @@ func (mr *MockTransactionMockRecorder) UpdateNode(node interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNode", reflect.TypeOf((*MockTransaction)(nil).UpdateNode), node) } -// UpdateNodeBy mocks base method. -func (m *MockTransaction) UpdateNodeBy(update graph.NodeUpdate) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateNodeBy", update) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateNodeBy indicates an expected call of UpdateNodeBy. -func (mr *MockTransactionMockRecorder) UpdateNodeBy(update interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateNodeBy", reflect.TypeOf((*MockTransaction)(nil).UpdateNodeBy), update) -} - // UpdateRelationship mocks base method. func (m *MockTransaction) UpdateRelationship(relationship *graph.Relationship) error { m.ctrl.T.Helper() @@ -649,18 +643,18 @@ func (mr *MockTransactionMockRecorder) UpdateRelationship(relationship interface return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRelationship", reflect.TypeOf((*MockTransaction)(nil).UpdateRelationship), relationship) } -// UpdateRelationshipBy mocks base method. -func (m *MockTransaction) UpdateRelationshipBy(update graph.RelationshipUpdate) error { +// WithGraph mocks base method. 
+func (m *MockTransaction) WithGraph(graphSchema graph.Graph) graph.Transaction { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateRelationshipBy", update) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "WithGraph", graphSchema) + ret0, _ := ret[0].(graph.Transaction) return ret0 } -// UpdateRelationshipBy indicates an expected call of UpdateRelationshipBy. -func (mr *MockTransactionMockRecorder) UpdateRelationshipBy(update interface{}) *gomock.Call { +// WithGraph indicates an expected call of WithGraph. +func (mr *MockTransactionMockRecorder) WithGraph(graphSchema interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRelationshipBy", reflect.TypeOf((*MockTransaction)(nil).UpdateRelationshipBy), update) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithGraph", reflect.TypeOf((*MockTransaction)(nil).WithGraph), graphSchema) } // MockDatabase is a mock of Database interface. @@ -687,17 +681,17 @@ func (m *MockDatabase) EXPECT() *MockDatabaseMockRecorder { } // AssertSchema mocks base method. -func (m *MockDatabase) AssertSchema(ctx context.Context, schema *graph.Schema) error { +func (m *MockDatabase) AssertSchema(ctx context.Context, dbSchema graph.Schema) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AssertSchema", ctx, schema) + ret := m.ctrl.Call(m, "AssertSchema", ctx, dbSchema) ret0, _ := ret[0].(error) return ret0 } // AssertSchema indicates an expected call of AssertSchema. 
-func (mr *MockDatabaseMockRecorder) AssertSchema(ctx, schema interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) AssertSchema(ctx, dbSchema interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertSchema", reflect.TypeOf((*MockDatabase)(nil).AssertSchema), ctx, schema) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AssertSchema", reflect.TypeOf((*MockDatabase)(nil).AssertSchema), ctx, dbSchema) } // BatchOperation mocks base method. @@ -715,32 +709,17 @@ func (mr *MockDatabaseMockRecorder) BatchOperation(ctx, batchDelegate interface{ } // Close mocks base method. -func (m *MockDatabase) Close() error { +func (m *MockDatabase) Close(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") + ret := m.ctrl.Call(m, "Close", ctx) ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. -func (mr *MockDatabaseMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatabase)(nil).Close)) -} - -// FetchSchema mocks base method. -func (m *MockDatabase) FetchSchema(ctx context.Context) (*graph.Schema, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FetchSchema", ctx) - ret0, _ := ret[0].(*graph.Schema) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// FetchSchema indicates an expected call of FetchSchema. -func (mr *MockDatabaseMockRecorder) FetchSchema(ctx interface{}) *gomock.Call { +func (mr *MockDatabaseMockRecorder) Close(ctx interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchSchema", reflect.TypeOf((*MockDatabase)(nil).FetchSchema), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDatabase)(nil).Close), ctx) } // ReadTransaction mocks base method. 
diff --git a/packages/go/dawgs/graph/mocks/query.go b/packages/go/dawgs/graph/mocks/query.go index f8b41600f7..ff8298adbe 100644 --- a/packages/go/dawgs/graph/mocks/query.go +++ b/packages/go/dawgs/graph/mocks/query.go @@ -172,11 +172,12 @@ func (mr *MockScannerMockRecorder) Scan(targets ...interface{}) *gomock.Call { } // Values mocks base method. -func (m *MockScanner) Values() graph.ValueMapper { +func (m *MockScanner) Values() (graph.ValueMapper, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Values") ret0, _ := ret[0].(graph.ValueMapper) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Values indicates an expected call of Values. @@ -267,11 +268,12 @@ func (mr *MockResultMockRecorder) Scan(targets ...interface{}) *gomock.Call { } // Values mocks base method. -func (m *MockResult) Values() graph.ValueMapper { +func (m *MockResult) Values() (graph.ValueMapper, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Values") ret0, _ := ret[0].(graph.ValueMapper) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // Values indicates an expected call of Values. @@ -318,21 +320,6 @@ func (mr *MockNodeQueryMockRecorder) Count() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockNodeQuery)(nil).Count)) } -// Debug mocks base method. -func (m *MockNodeQuery) Debug() (string, map[string]any) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Debug") - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(map[string]any) - return ret0, ret1 -} - -// Debug indicates an expected call of Debug. -func (mr *MockNodeQueryMockRecorder) Debug() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockNodeQuery)(nil).Debug)) -} - // Delete mocks base method. 
func (m *MockNodeQuery) Delete() error { m.ctrl.T.Helper() @@ -347,25 +334,6 @@ func (mr *MockNodeQueryMockRecorder) Delete() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockNodeQuery)(nil).Delete)) } -// Execute mocks base method. -func (m *MockNodeQuery) Execute(delegate func(graph.Result) error, finalCriteria ...graph.Criteria) error { - m.ctrl.T.Helper() - varargs := []interface{}{delegate} - for _, a := range finalCriteria { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Execute", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Execute indicates an expected call of Execute. -func (mr *MockNodeQueryMockRecorder) Execute(delegate interface{}, finalCriteria ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{delegate}, finalCriteria...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockNodeQuery)(nil).Execute), varargs...) -} - // Fetch mocks base method. func (m *MockNodeQuery) Fetch(delegate func(graph.Cursor[*graph.Node]) error) error { m.ctrl.T.Helper() @@ -497,6 +465,25 @@ func (mr *MockNodeQueryMockRecorder) OrderBy(criteria ...interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrderBy", reflect.TypeOf((*MockNodeQuery)(nil).OrderBy), criteria...) } +// Query mocks base method. +func (m *MockNodeQuery) Query(delegate func(graph.Result) error, finalCriteria ...graph.Criteria) error { + m.ctrl.T.Helper() + varargs := []interface{}{delegate} + for _, a := range finalCriteria { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Query indicates an expected call of Query. +func (mr *MockNodeQueryMockRecorder) Query(delegate interface{}, finalCriteria ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{delegate}, finalCriteria...) 
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockNodeQuery)(nil).Query), varargs...) +} + // Update mocks base method. func (m *MockNodeQuery) Update(properties *graph.Properties) error { m.ctrl.T.Helper() @@ -549,21 +536,6 @@ func (mr *MockRelationshipQueryMockRecorder) Count() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockRelationshipQuery)(nil).Count)) } -// Debug mocks base method. -func (m *MockRelationshipQuery) Debug() (string, map[string]any) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Debug") - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(map[string]any) - return ret0, ret1 -} - -// Debug indicates an expected call of Debug. -func (mr *MockRelationshipQueryMockRecorder) Debug() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockRelationshipQuery)(nil).Debug)) -} - // Delete mocks base method. func (m *MockRelationshipQuery) Delete() error { m.ctrl.T.Helper() @@ -578,25 +550,6 @@ func (mr *MockRelationshipQueryMockRecorder) Delete() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockRelationshipQuery)(nil).Delete)) } -// Execute mocks base method. -func (m *MockRelationshipQuery) Execute(delegate func(graph.Result) error, finalCriteria ...graph.Criteria) error { - m.ctrl.T.Helper() - varargs := []interface{}{delegate} - for _, a := range finalCriteria { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Execute", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Execute indicates an expected call of Execute. -func (mr *MockRelationshipQueryMockRecorder) Execute(delegate interface{}, finalCriteria ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{delegate}, finalCriteria...) 
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Execute", reflect.TypeOf((*MockRelationshipQuery)(nil).Execute), varargs...) -} - // Fetch mocks base method. func (m *MockRelationshipQuery) Fetch(delegate func(graph.Cursor[*graph.Relationship]) error) error { m.ctrl.T.Helper() @@ -770,6 +723,25 @@ func (mr *MockRelationshipQueryMockRecorder) OrderBy(criteria ...interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrderBy", reflect.TypeOf((*MockRelationshipQuery)(nil).OrderBy), criteria...) } +// Query mocks base method. +func (m *MockRelationshipQuery) Query(delegate func(graph.Result) error, finalCriteria ...graph.Criteria) error { + m.ctrl.T.Helper() + varargs := []interface{}{delegate} + for _, a := range finalCriteria { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Query indicates an expected call of Query. +func (mr *MockRelationshipQueryMockRecorder) Query(delegate interface{}, finalCriteria ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{delegate}, finalCriteria...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockRelationshipQuery)(nil).Query), varargs...) +} + // Update mocks base method. func (m *MockRelationshipQuery) Update(properties *graph.Properties) error { m.ctrl.T.Helper() diff --git a/packages/go/dawgs/graph/node.go b/packages/go/dawgs/graph/node.go index f0fd0d7706..afdcd23e6f 100644 --- a/packages/go/dawgs/graph/node.go +++ b/packages/go/dawgs/graph/node.go @@ -58,6 +58,24 @@ type Node struct { Properties *Properties `json:"properties"` } +func (s *Node) Merge(other *Node) { + s.Kinds = s.Kinds.Add(other.Kinds...) 
+ + for _, otherKind := range other.AddedKinds { + s.DeletedKinds = s.DeletedKinds.Remove(otherKind) + } + + for _, otherKind := range other.DeletedKinds { + s.Kinds = s.Kinds.Remove(otherKind) + s.AddedKinds = s.AddedKinds.Remove(otherKind) + } + + s.AddedKinds = s.AddedKinds.Add(other.AddedKinds...) + s.DeletedKinds = s.DeletedKinds.Add(other.DeletedKinds...) + + s.Properties.Merge(other.Properties) +} + func (s *Node) SizeOf() size.Size { nodeSize := size.Of(s) + s.Kinds.SizeOf() diff --git a/packages/go/dawgs/graph/properties.go b/packages/go/dawgs/graph/properties.go index d6b2f17b0f..99dddeed32 100644 --- a/packages/go/dawgs/graph/properties.go +++ b/packages/go/dawgs/graph/properties.go @@ -149,27 +149,15 @@ func (s safePropertyValue) Int() (int, error) { if rawValue, err := s.getValue(); err != nil { return 0, err } else { - switch typedValue := rawValue.(type) { - case int: - return typedValue, err - - case int64: - return int(typedValue), nil - } - - return 0, formatPropertyTypeError("int", rawValue) + return AsNumeric[int](rawValue) } } func (s safePropertyValue) Int64() (int64, error) { if rawValue, err := s.getValue(); err != nil { - return 0, err - } else if typedValue, typeOK := rawValue.(int64); !typeOK { - err := formatPropertyTypeError("int64", rawValue) - return 0, err } else { - return typedValue, nil + return AsNumeric[int64](rawValue) } } @@ -177,26 +165,15 @@ func (s safePropertyValue) Uint64() (uint64, error) { if rawValue, err := s.getValue(); err != nil { return 0, err } else { - switch typedValue := rawValue.(type) { - case uint64: - return typedValue, nil - case int64: - return uint64(typedValue), nil - default: - return 0, formatPropertyTypeError("uint64", rawValue) - } + return AsNumeric[uint64](rawValue) } } func (s safePropertyValue) Float64() (float64, error) { if rawValue, err := s.getValue(); err != nil { - return 0, err - } else if typedValue, typeOK := rawValue.(float64); !typeOK { - err := formatPropertyTypeError("float64", 
rawValue) - return 0, err } else { - return typedValue, nil + return AsNumeric[float64](rawValue) } } @@ -269,6 +246,33 @@ type Properties struct { Modified map[string]struct{} `json:"modified"` } +func (s *Properties) Merge(other *Properties) { + for otherKey, otherValue := range other.Map { + s.Map[otherKey] = otherValue + } + + for otherModifiedKey := range other.Modified { + s.Modified[otherModifiedKey] = struct{}{} + + delete(s.Deleted, otherModifiedKey) + } + + for otherDeletedKey := range other.Deleted { + s.Deleted[otherDeletedKey] = struct{}{} + + delete(s.Map, otherDeletedKey) + delete(s.Modified, otherDeletedKey) + } +} + +func (s *Properties) MapOrEmpty() map[string]any { + if s == nil || s.Map == nil { + return map[string]any{} + } + + return s.Map +} + func (s *Properties) SizeOf() size.Size { instanceSize := size.Of(*s) @@ -433,10 +437,19 @@ func (s *Properties) Delete(key string) *Properties { return s } +// TODO: This function does not correctly communicate that it is lazily instantiated func NewProperties() *Properties { return &Properties{} } +func NewPropertiesRed() *Properties { + return &Properties{ + Map: map[string]any{}, + Modified: make(map[string]struct{}), + Deleted: make(map[string]struct{}), + } +} + type PropertyMap map[String]any func symbolMapToStringMap(props map[String]any) map[string]any { diff --git a/packages/go/dawgs/graph/properties_test.go b/packages/go/dawgs/graph/properties_test.go index ee369f1677..653cb2a0ff 100644 --- a/packages/go/dawgs/graph/properties_test.go +++ b/packages/go/dawgs/graph/properties_test.go @@ -17,11 +17,10 @@ package graph_test import ( - "strconv" - "testing" - "github.com/specterops/bloodhound/dawgs/graph" "github.com/stretchr/testify/require" + "strconv" + "testing" ) func TestNewProperties(t *testing.T) { diff --git a/packages/go/dawgs/graph/query.go b/packages/go/dawgs/graph/query.go index 4e5fb33f24..e5282eb74f 100644 --- a/packages/go/dawgs/graph/query.go +++ 
b/packages/go/dawgs/graph/query.go @@ -27,7 +27,7 @@ type ValueMapper interface { type Scanner interface { Next() bool - Values() ValueMapper + Values() (ValueMapper, error) Scan(targets ...any) error } @@ -38,6 +38,35 @@ type Result interface { Close() } +type ErrorResult struct { + err error +} + +func (s ErrorResult) Next() bool { + return false +} + +func (s ErrorResult) Values() (ValueMapper, error) { + return nil, s.err +} + +func (s ErrorResult) Scan(targets ...any) error { + return s.err +} + +func (s ErrorResult) Error() error { + return s.err +} + +func (s ErrorResult) Close() { +} + +func NewErrorResult(err error) Result { + return ErrorResult{ + err: err, + } +} + // Criteria is a top-level alias for communicating structured query filter criteria to a query generator. type Criteria any @@ -53,8 +82,8 @@ type NodeQuery interface { // Filterf applies the given criteria provider function to this query. Filterf(criteriaDelegate CriteriaProvider) NodeQuery - // Execute completes the query and hands the raw result to the given delegate for unmarshalling - Execute(delegate func(results Result) error, finalCriteria ...Criteria) error + // Query completes the query and hands the raw result to the given delegate for unmarshalling + Query(delegate func(results Result) error, finalCriteria ...Criteria) error // Delete deletes any candidate nodes that match the query criteria Delete() error @@ -88,8 +117,6 @@ type NodeQuery interface { // FetchKinds returns the ID and Kinds of matched nodes and omits property fetching FetchKinds(func(cursor Cursor[KindsResult]) error) error - - Debug() (string, map[string]any) } // RelationshipQuery is an interface that covers all supported relationship query combinations. The contract supports a @@ -124,8 +151,8 @@ type RelationshipQuery interface { // First completes the query and returns the result and any error encountered during execution. 
First() (*Relationship, error) - // Execute completes the query and hands the raw result to the given delegate for unmarshalling - Execute(delegate func(results Result) error, finalCriteria ...Criteria) error + // Query completes the query and hands the raw result to the given delegate for unmarshalling + Query(delegate func(results Result) error, finalCriteria ...Criteria) error // Fetch completes the query and captures a cursor for iterating the result set. This cursor is passed to the given // delegate. Errors from the delegate are returned upwards as the error result of this call. @@ -147,6 +174,4 @@ type RelationshipQuery interface { // FetchKinds returns the ID, Kind, Start ID and End ID of matched relationships and omits property fetching FetchKinds(delegate func(cursor Cursor[RelationshipKindsResult]) error) error - - Debug() (string, map[string]any) } diff --git a/packages/go/dawgs/graph/relationships.go b/packages/go/dawgs/graph/relationships.go index f2d5d275bf..6033c341b8 100644 --- a/packages/go/dawgs/graph/relationships.go +++ b/packages/go/dawgs/graph/relationships.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package graph @@ -29,7 +29,11 @@ type Relationship struct { Properties *Properties } -func (s Relationship) SizeOf() size.Size { +func (s *Relationship) Merge(other *Relationship) { + s.Properties.Merge(other.Properties) +} + +func (s *Relationship) SizeOf() size.Size { relSize := size.Of(s) + size.Of(s.Kind) if s.Properties != nil { diff --git a/packages/go/dawgs/graph/schema.go b/packages/go/dawgs/graph/schema.go index e583e8b924..0bc18c00e6 100644 --- a/packages/go/dawgs/graph/schema.go +++ b/packages/go/dawgs/graph/schema.go @@ -1,218 +1,61 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package graph -import ( - "fmt" - "strings" -) - type IndexType int const ( - UnsupportedIndex IndexType = 0 - BTreeIndex IndexType = 1 - FullTextSearchIndex IndexType = 2 + UnsupportedIndex IndexType = 0 + BTreeIndex IndexType = 1 + TextSearchIndex IndexType = 2 ) func (s IndexType) String() string { switch s { case BTreeIndex: - return "BTreeIndex" - case FullTextSearchIndex: - return "FullTextSearchIndex" - case UnsupportedIndex: - fallthrough - default: - return "UnsupportedIndex" - } -} - -type ConstraintSchema struct { - Name string - IndexType IndexType -} - -func (s ConstraintSchema) Equals(other ConstraintSchema) bool { - return s.Name == other.Name && s.IndexType == other.IndexType -} - -type IndexSchema struct { - Name string - IndexType IndexType -} - -func (s IndexSchema) Equals(other IndexSchema) bool { - return s.Name == other.Name && s.IndexType == other.IndexType -} - -type KindSchema struct { - Kind Kind - PropertyIndices map[string]IndexSchema - PropertyConstraints map[string]ConstraintSchema -} + return "btree" -func (s *KindSchema) Name() string { - return s.Kind.String() -} + case TextSearchIndex: + return "fts" -func (s *KindSchema) Constraint(property, name string, indexType IndexType) { - s.PropertyConstraints[property] = ConstraintSchema{ - Name: name, - IndexType: indexType, + default: + return "invalid" } } -func (s *KindSchema) ConstrainProperty(property string, indexType IndexType) { - s.Constraint(property, fmt.Sprintf("%s_%s_constraint", strings.ToLower(s.Name()), strings.ToLower(property)), indexType) -} - -func (s *KindSchema) Index(property, name string, indexType IndexType) { - s.PropertyIndices[property] = IndexSchema{ - Name: name, - IndexType: indexType, - } +type Index struct { + Name string + Field string + Type IndexType } -func (s *KindSchema) IndexProperty(property string, indexType IndexType) { - s.Index(property, fmt.Sprintf("%s_%s_index", strings.ToLower(s.Name()), 
strings.ToLower(property)), indexType) -} +type Constraint Index -type KindSchemaContinuation struct { - kinds []*KindSchema -} - -func (s KindSchemaContinuation) Constrain(name string, indexType IndexType) KindSchemaContinuation { - for _, label := range s.kinds { - label.ConstrainProperty(name, indexType) - } - - return s -} - -func (s KindSchemaContinuation) Index(name string, indexType IndexType) KindSchemaContinuation { - for _, label := range s.kinds { - label.IndexProperty(name, indexType) - } - - return s +type Graph struct { + Name string + Nodes Kinds + Edges Kinds + NodeConstraints []Constraint + EdgeConstraints []Constraint + NodeIndexes []Index + EdgeIndexes []Index } type Schema struct { - Kinds map[Kind]*KindSchema -} - -func NewSchema() *Schema { - return &Schema{ - Kinds: make(map[Kind]*KindSchema), - } -} - -func (s *Schema) ForKinds(kinds ...Kind) KindSchemaContinuation { - var selectedKinds []*KindSchema - - for _, kind := range kinds { - if kind, found := s.Kinds[kind]; found { - selectedKinds = append(selectedKinds, kind) - } - } - - return KindSchemaContinuation{ - kinds: selectedKinds, - } -} - -func (s *Schema) Kind(kind Kind) *KindSchema { - return s.Kinds[kind] -} - -func (s *Schema) EnsureKind(kind Kind) *KindSchema { - if label, found := s.Kinds[kind]; found { - return label - } else { - newLabel := &KindSchema{ - Kind: kind, - PropertyIndices: make(map[string]IndexSchema), - PropertyConstraints: make(map[string]ConstraintSchema), - } - - s.Kinds[kind] = newLabel - return newLabel - } -} - -func (s *Schema) DefineKinds(kinds ...Kind) { - for _, kind := range kinds { - s.Kinds[kind] = &KindSchema{ - Kind: kind, - PropertyIndices: make(map[string]IndexSchema), - PropertyConstraints: make(map[string]ConstraintSchema), - } - } -} - -func (s *Schema) ConstrainProperty(name string, indexType IndexType) { - for _, kindSchema := range s.Kinds { - kindSchema.PropertyConstraints[name] = ConstraintSchema{ - Name: fmt.Sprintf("%s_%s_constraint", 
strings.ToLower(kindSchema.Name()), strings.ToLower(name)), - IndexType: indexType, - } - } -} - -func (s *Schema) IndexProperty(name string, indexType IndexType) { - for _, labelSchema := range s.Kinds { - labelSchema.PropertyIndices[name] = IndexSchema{ - Name: fmt.Sprintf("%s_%s_index", strings.ToLower(labelSchema.Name()), strings.ToLower(name)), - IndexType: indexType, - } - } -} - -func (s *Schema) String() string { - output := strings.Builder{} - - for _, kindSchema := range s.Kinds { - output.WriteString("Label: ") - output.WriteString(kindSchema.Name()) - output.WriteRune('\n') - - for propertyName, constraint := range kindSchema.PropertyConstraints { - output.WriteString("\t") - output.WriteString(propertyName) - output.WriteString(" ") - output.WriteString(constraint.Name) - output.WriteString("[") - output.WriteString(constraint.IndexType.String()) - output.WriteString("]\n") - } - - for propertyName, index := range kindSchema.PropertyIndices { - output.WriteString("\t") - output.WriteString(propertyName) - output.WriteString(" ") - output.WriteString(index.Name) - output.WriteString("[") - output.WriteString(index.IndexType.String()) - output.WriteString("]\n") - } - - output.WriteRune('\n') - } - - return output.String() + Graphs []Graph + DefaultGraph Graph } diff --git a/packages/go/dawgs/graph/switch.go b/packages/go/dawgs/graph/switch.go new file mode 100644 index 0000000000..22fdc11194 --- /dev/null +++ b/packages/go/dawgs/graph/switch.go @@ -0,0 +1,182 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package graph + +import ( + "context" + "errors" + "sync" +) + +var ( + ErrAuthoritativeDatabaseSwitching = errors.New("switching authoritative database") +) + +type DatabaseSwitch struct { + activeContexts map[any]func() + currentDB Database + inSwitch bool + ctxLock *sync.Mutex + currentDBLock *sync.RWMutex + writeFlushSize int + batchWriteSize int +} + +func NewDatabaseSwitch(ctx context.Context, initialDB Database) *DatabaseSwitch { + return &DatabaseSwitch{ + activeContexts: map[any]func(){}, + currentDB: initialDB, + inSwitch: false, + ctxLock: &sync.Mutex{}, + currentDBLock: &sync.RWMutex{}, + } +} + +func (s *DatabaseSwitch) Switch(db Database) { + s.inSwitch = true + + defer func() { + s.inSwitch = false + }() + + s.cancelInternalContexts() + + s.currentDBLock.Lock() + defer s.currentDBLock.Unlock() + + s.currentDB = db +} + +func (s *DatabaseSwitch) SetWriteFlushSize(interval int) { + s.writeFlushSize = interval +} + +func (s *DatabaseSwitch) SetBatchWriteSize(interval int) { + s.batchWriteSize = interval +} + +func (s *DatabaseSwitch) newInternalContext(ctx context.Context) (context.Context, error) { + s.ctxLock.Lock() + defer s.ctxLock.Unlock() + + // Do not issue new contexts if we're in the process of switching authoritative databases + if s.inSwitch { + return nil, ErrAuthoritativeDatabaseSwitching + } + + internalCtx, doneFunc := context.WithCancel(ctx) + + s.activeContexts[internalCtx] = doneFunc + return internalCtx, nil +} + +func (s *DatabaseSwitch) cancelInternalContexts() { + s.ctxLock.Lock() + defer s.ctxLock.Unlock() + + for _, doneFunc := range s.activeContexts { + doneFunc() + } +} + +func (s *DatabaseSwitch) retireInternalContext(ctx context.Context) { + s.ctxLock.Lock() + defer s.ctxLock.Unlock() + + if doneFunc, exists := s.activeContexts[ctx]; exists { + doneFunc() + 
delete(s.activeContexts, ctx)
+ }
+}
+
+func (s *DatabaseSwitch) ReadTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error {
+ if internalCtx, err := s.newInternalContext(ctx); err != nil {
+ return err
+ } else {
+ defer s.retireInternalContext(internalCtx)
+
+ s.currentDBLock.RLock()
+ defer s.currentDBLock.RUnlock()
+
+ return s.currentDB.ReadTransaction(internalCtx, txDelegate, options...)
+ }
+}
+
+func (s *DatabaseSwitch) WriteTransaction(ctx context.Context, txDelegate TransactionDelegate, options ...TransactionOption) error {
+ if internalCtx, err := s.newInternalContext(ctx); err != nil {
+ return err
+ } else {
+ defer s.retireInternalContext(internalCtx)
+
+ s.currentDBLock.RLock()
+ defer s.currentDBLock.RUnlock()
+
+ return s.currentDB.WriteTransaction(internalCtx, txDelegate, options...)
+ }
+}
+
+func (s *DatabaseSwitch) BatchOperation(ctx context.Context, batchDelegate BatchDelegate) error {
+ if internalCtx, err := s.newInternalContext(ctx); err != nil {
+ return err
+ } else {
+ defer s.retireInternalContext(internalCtx)
+
+ s.currentDBLock.RLock()
+ defer s.currentDBLock.RUnlock()
+
+ return s.currentDB.BatchOperation(internalCtx, batchDelegate)
+ }
+}
+
+func (s *DatabaseSwitch) AssertSchema(ctx context.Context, dbSchema Schema) error {
+ if internalCtx, err := s.newInternalContext(ctx); err != nil {
+ return err
+ } else {
+ defer s.retireInternalContext(internalCtx)
+
+ s.currentDBLock.RLock()
+ defer s.currentDBLock.RUnlock()
+
+ return s.currentDB.AssertSchema(internalCtx, dbSchema)
+ }
+}
+
+func (s *DatabaseSwitch) Run(ctx context.Context, query string, parameters map[string]any) error {
+ if internalCtx, err := s.newInternalContext(ctx); err != nil {
+ return err
+ } else {
+ defer s.retireInternalContext(internalCtx)
+
+ s.currentDBLock.RLock()
+ defer s.currentDBLock.RUnlock()
+
+ return s.currentDB.Run(internalCtx, query, parameters)
+ }
+}
+
+func (s *DatabaseSwitch) Close(ctx context.Context) 
error {
+ if internalCtx, err := s.newInternalContext(ctx); err != nil {
+ return err
+ } else {
+ defer s.retireInternalContext(internalCtx)
+
+ s.currentDBLock.RLock()
+ defer s.currentDBLock.RUnlock()
+
+ return s.currentDB.Close(internalCtx)
+ }
+}
diff --git a/packages/go/dawgs/graph/test/test.go b/packages/go/dawgs/graph/test/test.go
index 89c04975d6..5cb4c98492 100644
--- a/packages/go/dawgs/graph/test/test.go
+++ b/packages/go/dawgs/graph/test/test.go
@@ -1,28 +1,40 @@
 // Copyright 2023 Specter Ops, Inc.
-//
+//
 // Licensed under the Apache License, Version 2.0
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
-//
+//
 // http://www.apache.org/licenses/LICENSE-2.0
-//
+//
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-//
+//
 // SPDX-License-Identifier: Apache-2.0
 
 package test
 
 import (
+	"github.com/stretchr/testify/require"
 	"testing"
+	"time"
 
-	"github.com/stretchr/testify/require"
 	"github.com/specterops/bloodhound/dawgs/graph"
 )
 
-func RequireProperty(t *testing.T, expected any, actual graph.PropertyValue) {
-	require.Equal(t, expected, actual.Any())
+func RequireProperty[T any](t *testing.T, expected T, actual graph.PropertyValue, msg ...any) {
+	var (
+		value = actual.Any()
+		err   error
+	)
+
+	switch any(expected).(type) {
+	case time.Time:
+		value, err = actual.Time()
+	}
+
+	require.Nil(t, err, msg...)
+	require.Equal(t, expected, value, msg...)
} diff --git a/packages/go/dawgs/ops/ops.go b/packages/go/dawgs/ops/ops.go index abb1ae4636..5d5cc81b50 100644 --- a/packages/go/dawgs/ops/ops.go +++ b/packages/go/dawgs/ops/ops.go @@ -27,6 +27,18 @@ import ( "github.com/specterops/bloodhound/dawgs/util/size" ) +func FetchAllNodeProperties(tx graph.Transaction, nodes graph.NodeSet) error { + return tx.Nodes().Filter( + query.InIDs(query.NodeID(), nodes.IDs()...), + ).Fetch(func(cursor graph.Cursor[*graph.Node]) error { + for next := range cursor.Chan() { + nodes[next.ID] = next + } + + return cursor.Error() + }) +} + func FetchNodeProperties(tx graph.Transaction, nodes graph.NodeSet, propertyNames []string) error { returningCriteria := make([]graph.Criteria, len(propertyNames)+1) returningCriteria[0] = query.NodeID() @@ -37,31 +49,32 @@ func FetchNodeProperties(tx graph.Transaction, nodes graph.NodeSet, propertyName return tx.Nodes().Filter( query.InIDs(query.NodeID(), nodes.IDs()...), - ).Execute(func(results graph.Result) error { + ).Query(func(results graph.Result) error { var nodeID graph.ID for results.Next() { - var ( - mapper = results.Values() - nodeProperties = map[string]any{} - ) - - // Map the node ID first - if err := mapper.Map(&nodeID); err != nil { + if values, err := results.Values(); err != nil { return err - } + } else { + nodeProperties := map[string]any{} - // Map requested properties next by matching the name index - for idx := 0; idx < len(propertyNames); idx++ { - if next, err := mapper.Next(); err != nil { + // Map the node ID first + if err := values.Map(&nodeID); err != nil { return err - } else { - nodeProperties[propertyNames[idx]] = next } - } - // Update the node in the node set - nodes[nodeID].Properties = graph.AsProperties(nodeProperties) + // Map requested properties next by matching the name index + for idx := 0; idx < len(propertyNames); idx++ { + if next, err := values.Next(); err != nil { + return err + } else { + nodeProperties[propertyNames[idx]] = next + } + } + + // Update 
the node in the node set + nodes[nodeID].Properties = graph.AsProperties(nodeProperties) + } } return nil @@ -146,7 +159,7 @@ func FetchPathSetByQuery(tx graph.Transaction, query string) (graph.PathSet, err pathSet graph.PathSet ) - if result := tx.Run(query, map[string]any{}); result.Error() != nil { + if result := tx.Query(query, map[string]any{}); result.Error() != nil { return pathSet, result.Error() } else { defer result.Close() @@ -158,7 +171,9 @@ func FetchPathSetByQuery(tx graph.Transaction, query string) (graph.PathSet, err path graph.Path ) - if mapped, err := result.Values().MapOptions(&relationship, &node, &path); err != nil { + if values, err := result.Values(); err != nil { + return pathSet, err + } else if mapped, err := values.MapOptions(&relationship, &node, &path); err != nil { return pathSet, err } else { switch typedMapped := mapped.(type) { @@ -323,22 +338,32 @@ func FetchRelationshipIDs(query graph.RelationshipQuery) ([]graph.ID, error) { }) } -func FetchPathSet(tx graph.Transaction, query graph.RelationshipQuery) (graph.PathSet, error) { +func FetchPathSet(queryInst graph.RelationshipQuery) (graph.PathSet, error) { pathSet := graph.NewPathSet() - return pathSet, query.Fetch(func(cursor graph.Cursor[*graph.Relationship]) error { - for rel := range cursor.Chan() { - if start, end, err := FetchRelationshipNodes(tx, rel); err != nil { + + return pathSet, queryInst.Query(func(results graph.Result) error { + defer results.Close() + + for results.Next() { + var ( + start, end graph.Node + edge graph.Relationship + ) + + if err := results.Scan(&start, &edge, &end); err != nil { return err } else { pathSet.AddPath(graph.Path{ - Nodes: []*graph.Node{start, end}, - Edges: []*graph.Relationship{rel}, + Nodes: []*graph.Node{&start, &end}, + Edges: []*graph.Relationship{&edge}, }) } } - return nil - }) + return results.Error() + }, query.Returning( + query.Start(), query.Relationship(), query.End(), + )) } func FetchRelationshipNodes(tx graph.Transaction, 
relationship *graph.Relationship) (*graph.Node, *graph.Node, error) { diff --git a/packages/go/dawgs/ops/traversal.go b/packages/go/dawgs/ops/traversal.go index 9eea6a316d..d6d3a2e85e 100644 --- a/packages/go/dawgs/ops/traversal.go +++ b/packages/go/dawgs/ops/traversal.go @@ -19,6 +19,7 @@ package ops import ( "errors" "fmt" + "github.com/specterops/bloodhound/log" "github.com/RoaringBitmap/roaring/roaring64" @@ -155,6 +156,8 @@ func AcyclicTraversal(tx graph.Transaction, plan TraversalPlan, pathVisitors ... } func Traversal(tx graph.Transaction, plan TraversalPlan, pathVisitors ...PathVisitor) error { + defer log.Measure(log.LevelInfo, "Node %d Traversal", plan.Root.ID)() + var ( pathVisitor PathVisitor requireTraversalOrder = plan.Limit > 0 || plan.Skip > 0 diff --git a/packages/go/dawgs/query/builder.go b/packages/go/dawgs/query/builder.go new file mode 100644 index 0000000000..4af376d290 --- /dev/null +++ b/packages/go/dawgs/query/builder.go @@ -0,0 +1,285 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package query + +import ( + "errors" + "fmt" + "github.com/specterops/bloodhound/cypher/model" + "github.com/specterops/bloodhound/dawgs/graph" +) + +var ( + ErrAmbiguousQueryVariables = errors.New("query mixes node and relationship query variables") +) + +type Cache struct { +} + +type Builder struct { + regularQuery *model.RegularQuery + cache *Cache +} + +func NewBuilder(cache *Cache) *Builder { + return &Builder{ + regularQuery: EmptySinglePartQuery(), + cache: cache, + } +} + +func NewBuilderWithCriteria(criteria ...graph.Criteria) *Builder { + builder := NewBuilder(nil) + builder.Apply(criteria...) + + return builder +} + +func (s *Builder) RegularQuery() *model.RegularQuery { + return s.regularQuery +} + +func (s *Builder) Build() (*model.RegularQuery, error) { + if err := s.prepareMatch(); err != nil { + return nil, err + } + + return s.regularQuery, nil +} + +func (s *Builder) prepareMatch() error { + var ( + patternPart = &model.PatternPart{} + + singleNodeBound = false + creatingSingleNode = false + + startNodeBound = false + creatingStartNode = false + endNodeBound = false + creatingEndNode = false + edgeBound = false + creatingEdge = false + + isRelationshipQuery = false + + bindWalk = model.NewVisitor(func(stack *model.WalkStack, branch model.Expression) error { + switch typedElement := branch.(type) { + case *model.Variable: + switch typedElement.Symbol { + case NodeSymbol: + singleNodeBound = true + + case EdgeStartSymbol: + startNodeBound = true + isRelationshipQuery = true + + case EdgeEndSymbol: + endNodeBound = true + isRelationshipQuery = true + + case EdgeSymbol: + edgeBound = true + isRelationshipQuery = true + } + } + + return nil + }, nil) + ) + + // Zip through updating clauses first + for _, updatingClause := range s.regularQuery.SingleQuery.SinglePartQuery.UpdatingClauses { + typedUpdatingClause, isUpdatingClause := updatingClause.(*model.UpdatingClause) + + if !isUpdatingClause { + return 
fmt.Errorf("unexpected type for updating clause: %T", updatingClause) + } + + switch typedClause := typedUpdatingClause.Clause.(type) { + case *model.Create: + if err := model.Walk(typedClause, model.NewVisitor(func(stack *model.WalkStack, element model.Expression) error { + switch typedElement := element.(type) { + case *model.NodePattern: + if patternBinding, typeOK := typedElement.Binding.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable for pattern binding but got: %T", typedElement.Binding) + } else { + switch patternBinding.Symbol { + case NodeSymbol: + creatingSingleNode = true + + case EdgeStartSymbol: + creatingStartNode = true + + case EdgeEndSymbol: + creatingEndNode = true + } + } + + case *model.RelationshipPattern: + if patternBinding, typeOK := typedElement.Binding.(*model.Variable); !typeOK { + return fmt.Errorf("expected variable for pattern binding but got: %T", typedElement.Binding) + } else { + switch patternBinding.Symbol { + case EdgeSymbol: + creatingEdge = true + } + } + } + + return nil + }, nil)); err != nil { + return err + } + + case *model.Delete: + if err := model.Walk(typedClause, bindWalk); err != nil { + return err + } + } + } + + // Is there a where clause? 
+ if firstReadingClause := GetFirstReadingClause(s.regularQuery); firstReadingClause != nil && firstReadingClause.Match.Where != nil { + if err := model.Walk(firstReadingClause.Match.Where, bindWalk); err != nil { + return err + } + } + + // Is there a return clause + if s.regularQuery.SingleQuery.SinglePartQuery.Return != nil { + if err := model.Walk(s.regularQuery.SingleQuery.SinglePartQuery.Return, bindWalk); err != nil { + return err + } + } + + // Validate we're not mixing references + if isRelationshipQuery && singleNodeBound { + return ErrAmbiguousQueryVariables + } + + if singleNodeBound && !creatingSingleNode { + // Bind the single-node variable + patternPart.AddPatternElements(&model.NodePattern{ + Binding: model.NewVariableWithSymbol(NodeSymbol), + }) + } + + if startNodeBound { + // Bind the start-node variable + patternPart.AddPatternElements(&model.NodePattern{ + Binding: model.NewVariableWithSymbol(EdgeStartSymbol), + }) + } + + if isRelationshipQuery { + if !startNodeBound && !creatingStartNode { + // Add an empty node pattern if the start node isn't bound, and we aren't creating it + patternPart.AddPatternElements(&model.NodePattern{}) + } + + if !creatingEdge { + if edgeBound { + // Bind the edge variable + patternPart.AddPatternElements(&model.RelationshipPattern{ + Binding: model.NewVariableWithSymbol(EdgeSymbol), + Direction: graph.DirectionOutbound, + }) + } else { + patternPart.AddPatternElements(&model.RelationshipPattern{ + Direction: graph.DirectionOutbound, + }) + } + } + + if !endNodeBound && !creatingEndNode { + patternPart.AddPatternElements(&model.NodePattern{}) + } + } + + if endNodeBound { + // Add an empty node pattern if the end node isn't bound, and we aren't creating it + patternPart.AddPatternElements(&model.NodePattern{ + Binding: model.NewVariableWithSymbol(EdgeEndSymbol), + }) + } + + if firstReadingClause := GetFirstReadingClause(s.regularQuery); firstReadingClause != nil { + firstReadingClause.Match.Pattern = 
[]*model.PatternPart{patternPart} + } else if len(patternPart.PatternElements) > 0 { + s.regularQuery.SingleQuery.SinglePartQuery.AddReadingClause(&model.ReadingClause{ + Match: &model.Match{ + Pattern: []*model.PatternPart{ + patternPart, + }, + }, + }) + } + + return nil +} + +func (s *Builder) Apply(criteria ...graph.Criteria) { + for _, nextCriteria := range criteria { + switch typedCriteria := nextCriteria.(type) { + case []graph.Criteria: + s.Apply(typedCriteria...) + + case *model.Where: + firstReadingClause := GetFirstReadingClause(s.regularQuery) + + if firstReadingClause == nil { + firstReadingClause = &model.ReadingClause{ + Match: model.NewMatch(false), + } + + s.regularQuery.SingleQuery.SinglePartQuery.AddReadingClause(firstReadingClause) + } + + firstReadingClause.Match.Where = model.Copy(typedCriteria) + + case *model.Return: + s.regularQuery.SingleQuery.SinglePartQuery.Return = typedCriteria + + case *model.Limit: + if s.regularQuery.SingleQuery.SinglePartQuery.Return != nil { + s.regularQuery.SingleQuery.SinglePartQuery.Return.Projection.Limit = model.Copy(typedCriteria) + } + + case *model.Skip: + if s.regularQuery.SingleQuery.SinglePartQuery.Return != nil { + s.regularQuery.SingleQuery.SinglePartQuery.Return.Projection.Skip = model.Copy(typedCriteria) + } + + case *model.Order: + if s.regularQuery.SingleQuery.SinglePartQuery.Return != nil { + s.regularQuery.SingleQuery.SinglePartQuery.Return.Projection.Order = model.Copy(typedCriteria) + } + + case []*model.UpdatingClause: + for _, updatingClause := range typedCriteria { + s.Apply(updatingClause) + } + + case *model.UpdatingClause: + s.regularQuery.SingleQuery.SinglePartQuery.AddUpdatingClause(model.Copy(typedCriteria)) + + default: + panic(fmt.Errorf("invalid type for dawgs query: %T %+v", criteria, criteria)) + } + } +} diff --git a/packages/go/dawgs/query/identifiers.go b/packages/go/dawgs/query/identifiers.go index 2b0fd8226d..090c011ed6 100644 --- a/packages/go/dawgs/query/identifiers.go +++ 
b/packages/go/dawgs/query/identifiers.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package query @@ -26,19 +26,19 @@ func Variable(name string) *model.Variable { } } -func Identity(entity *model.Variable) *model.FunctionInvocation { +func Identity(entity model.Expression) *model.FunctionInvocation { return &model.FunctionInvocation{ Name: "id", - Arguments: convertCriteria[model.Expression](entity), + Arguments: []model.Expression{entity}, } } const ( - PathSymbol = "p" - NodeSymbol = "n" - RelationshipSymbol = "r" - RelationshipStartSymbol = "s" - RelationshipEndSymbol = "e" + PathSymbol = "p" + NodeSymbol = "n" + EdgeSymbol = "r" + EdgeStartSymbol = "s" + EdgeEndSymbol = "e" ) func Node() *model.Variable { @@ -50,7 +50,7 @@ func NodeID() *model.FunctionInvocation { } func Relationship() *model.Variable { - return Variable(RelationshipSymbol) + return Variable(EdgeSymbol) } func RelationshipID() *model.FunctionInvocation { @@ -58,7 +58,7 @@ func RelationshipID() *model.FunctionInvocation { } func Start() *model.Variable { - return Variable(RelationshipStartSymbol) + return Variable(EdgeStartSymbol) } func StartID() *model.FunctionInvocation { @@ -66,7 +66,7 @@ func StartID() *model.FunctionInvocation { } func End() *model.Variable { - return Variable(RelationshipEndSymbol) + return Variable(EdgeEndSymbol) } func EndID() *model.FunctionInvocation { diff --git 
a/packages/go/dawgs/query/model.go b/packages/go/dawgs/query/model.go index 1c59b21b62..6744506f70 100644 --- a/packages/go/dawgs/query/model.go +++ b/packages/go/dawgs/query/model.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package query @@ -21,7 +21,7 @@ import ( "strings" "time" - "github.com/specterops/bloodhound/cypher/model" + cypherModel "github.com/specterops/bloodhound/cypher/model" "github.com/specterops/bloodhound/dawgs/graph" ) @@ -37,52 +37,52 @@ func convertCriteria[T any](criteria ...graph.Criteria) []T { return converted } -func Update(clauses ...*model.UpdatingClause) []*model.UpdatingClause { +func Update(clauses ...*cypherModel.UpdatingClause) []*cypherModel.UpdatingClause { return clauses } -func Updatef(provider graph.CriteriaProvider) []*model.UpdatingClause { +func Updatef(provider graph.CriteriaProvider) []*cypherModel.UpdatingClause { switch typedCriteria := provider().(type) { - case []*model.UpdatingClause: + case []*cypherModel.UpdatingClause: return typedCriteria case []graph.Criteria: - return convertCriteria[*model.UpdatingClause](typedCriteria...) + return convertCriteria[*cypherModel.UpdatingClause](typedCriteria...) 
- case *model.UpdatingClause: - return []*model.UpdatingClause{typedCriteria} + case *cypherModel.UpdatingClause: + return []*cypherModel.UpdatingClause{typedCriteria} default: - return []*model.UpdatingClause{ - model.WithErrors(&model.UpdatingClause{}, fmt.Errorf("invalid type %T for update clause", typedCriteria)), + return []*cypherModel.UpdatingClause{ + cypherModel.WithErrors(&cypherModel.UpdatingClause{}, fmt.Errorf("invalid type %T for update clause", typedCriteria)), } } } -func AddKind(reference graph.Criteria, kind graph.Kind) *model.UpdatingClause { - return model.NewUpdatingClause(&model.Set{ - Items: []*model.SetItem{{ +func AddKind(reference graph.Criteria, kind graph.Kind) *cypherModel.UpdatingClause { + return cypherModel.NewUpdatingClause(&cypherModel.Set{ + Items: []*cypherModel.SetItem{{ Left: reference, - Operator: model.OperatorLabelAssignment, + Operator: cypherModel.OperatorLabelAssignment, Right: graph.Kinds{kind}, }}, }) } -func AddKinds(reference graph.Criteria, kinds graph.Kinds) *model.UpdatingClause { - return model.NewUpdatingClause(&model.Set{ - Items: []*model.SetItem{{ +func AddKinds(reference graph.Criteria, kinds graph.Kinds) *cypherModel.UpdatingClause { + return cypherModel.NewUpdatingClause(&cypherModel.Set{ + Items: []*cypherModel.SetItem{{ Left: reference, - Operator: model.OperatorLabelAssignment, + Operator: cypherModel.OperatorLabelAssignment, Right: kinds, }}, }) } -func DeleteKind(reference graph.Criteria, kind graph.Kind) *model.UpdatingClause { - return model.NewUpdatingClause(&model.Remove{ - Items: []*model.RemoveItem{{ - KindMatcher: &model.KindMatcher{ +func DeleteKind(reference graph.Criteria, kind graph.Kind) *cypherModel.UpdatingClause { + return cypherModel.NewUpdatingClause(&cypherModel.Remove{ + Items: []*cypherModel.RemoveItem{{ + KindMatcher: &cypherModel.KindMatcher{ Reference: reference, Kinds: graph.Kinds{kind}, }, @@ -90,10 +90,10 @@ func DeleteKind(reference graph.Criteria, kind graph.Kind) 
*model.UpdatingClause }) } -func DeleteKinds(reference graph.Criteria, kinds graph.Kinds) *model.Remove { - return &model.Remove{ - Items: []*model.RemoveItem{{ - KindMatcher: &model.KindMatcher{ +func DeleteKinds(reference graph.Criteria, kinds graph.Kinds) *cypherModel.Remove { + return &cypherModel.Remove{ + Items: []*cypherModel.RemoveItem{{ + KindMatcher: &cypherModel.KindMatcher{ Reference: reference, Kinds: kinds, }, @@ -101,58 +101,58 @@ func DeleteKinds(reference graph.Criteria, kinds graph.Kinds) *model.Remove { } } -func SetProperty(reference graph.Criteria, value any) *model.UpdatingClause { - return model.NewUpdatingClause(&model.Set{ - Items: []*model.SetItem{{ +func SetProperty(reference graph.Criteria, value any) *cypherModel.UpdatingClause { + return cypherModel.NewUpdatingClause(&cypherModel.Set{ + Items: []*cypherModel.SetItem{{ Left: reference, - Operator: model.OperatorAssignment, + Operator: cypherModel.OperatorAssignment, Right: Parameter(value), }}, }) } -func SetProperties(reference graph.Criteria, properties map[string]any) *model.UpdatingClause { - set := &model.Set{} +func SetProperties(reference graph.Criteria, properties map[string]any) *cypherModel.UpdatingClause { + set := &cypherModel.Set{} for key, value := range properties { - set.Items = append(set.Items, &model.SetItem{ + set.Items = append(set.Items, &cypherModel.SetItem{ Left: Property(reference, key), - Operator: model.OperatorAssignment, + Operator: cypherModel.OperatorAssignment, Right: Parameter(value), }) } - return model.NewUpdatingClause(set) + return cypherModel.NewUpdatingClause(set) } -func DeleteProperty(reference *model.PropertyLookup) *model.UpdatingClause { - return model.NewUpdatingClause(&model.Remove{ - Items: []*model.RemoveItem{{ +func DeleteProperty(reference *cypherModel.PropertyLookup) *cypherModel.UpdatingClause { + return cypherModel.NewUpdatingClause(&cypherModel.Remove{ + Items: []*cypherModel.RemoveItem{{ Property: reference, }}, }) } -func 
DeleteProperties(reference graph.Criteria, propertyNames ...string) *model.UpdatingClause { - removeClause := &model.Remove{} +func DeleteProperties(reference graph.Criteria, propertyNames ...string) *cypherModel.UpdatingClause { + removeClause := &cypherModel.Remove{} for _, propertyName := range propertyNames { - removeClause.Items = append(removeClause.Items, &model.RemoveItem{ + removeClause.Items = append(removeClause.Items, &cypherModel.RemoveItem{ Property: Property(reference, propertyName), }) } - return model.NewUpdatingClause(removeClause) + return cypherModel.NewUpdatingClause(removeClause) } -func Kind(reference graph.Criteria, kind graph.Kind) *model.KindMatcher { - return &model.KindMatcher{ +func Kind(reference graph.Criteria, kind graph.Kind) *cypherModel.KindMatcher { + return &cypherModel.KindMatcher{ Reference: reference, Kinds: graph.Kinds{kind}, } } -func KindIn(reference graph.Criteria, kinds ...graph.Kind) *model.Parenthetical { +func KindIn(reference graph.Criteria, kinds ...graph.Kind) *cypherModel.Parenthetical { expressions := make([]graph.Criteria, len(kinds)) for idx, kind := range kinds { @@ -162,320 +162,329 @@ func KindIn(reference graph.Criteria, kinds ...graph.Kind) *model.Parenthetical return Or(expressions...) 
} -func NodeProperty(name string) *model.PropertyLookup { - return model.NewPropertyLookup(NodeSymbol, name) +func NodeProperty(name string) *cypherModel.PropertyLookup { + return cypherModel.NewPropertyLookup(NodeSymbol, name) } -func RelationshipProperty(name string) *model.PropertyLookup { - return model.NewPropertyLookup(RelationshipSymbol, name) +func RelationshipProperty(name string) *cypherModel.PropertyLookup { + return cypherModel.NewPropertyLookup(EdgeSymbol, name) } -func StartProperty(name string) *model.PropertyLookup { - return model.NewPropertyLookup(RelationshipStartSymbol, name) +func StartProperty(name string) *cypherModel.PropertyLookup { + return cypherModel.NewPropertyLookup(EdgeStartSymbol, name) } -func EndProperty(name string) *model.PropertyLookup { - return model.NewPropertyLookup(RelationshipEndSymbol, name) +func EndProperty(name string) *cypherModel.PropertyLookup { + return cypherModel.NewPropertyLookup(EdgeEndSymbol, name) } -func Property(qualifier graph.Criteria, name string) *model.PropertyLookup { - return &model.PropertyLookup{ - Atom: qualifier.(*model.Variable), +func Property(qualifier graph.Criteria, name string) *cypherModel.PropertyLookup { + return &cypherModel.PropertyLookup{ + Atom: qualifier.(*cypherModel.Variable), Symbols: []string{name}, } } -func Count(reference graph.Criteria) graph.Criteria { - return &model.FunctionInvocation{ +func Count(reference graph.Criteria) *cypherModel.FunctionInvocation { + return &cypherModel.FunctionInvocation{ Name: "count", - Arguments: []model.Expression{reference}, + Arguments: []cypherModel.Expression{reference}, } } -func And(criteria ...graph.Criteria) *model.Conjunction { - return model.NewConjunction(convertCriteria[model.Expression](criteria...)...) 
+func CountDistinct(reference graph.Criteria) *cypherModel.FunctionInvocation { + return &cypherModel.FunctionInvocation{ + Name: "count", + Distinct: true, + Arguments: []cypherModel.Expression{reference}, + } +} + +func And(criteria ...graph.Criteria) *cypherModel.Conjunction { + return cypherModel.NewConjunction(convertCriteria[cypherModel.Expression](criteria...)...) } -func Or(criteria ...graph.Criteria) *model.Parenthetical { - return &model.Parenthetical{ - Expression: model.NewDisjunction(convertCriteria[model.Expression](criteria...)...), +func Or(criteria ...graph.Criteria) *cypherModel.Parenthetical { + return &cypherModel.Parenthetical{ + Expression: cypherModel.NewDisjunction(convertCriteria[cypherModel.Expression](criteria...)...), } } -func Xor(criteria ...graph.Criteria) *model.ExclusiveDisjunction { - return model.NewExclusiveDisjunction(convertCriteria[model.Expression](criteria...)...) +func Xor(criteria ...graph.Criteria) *cypherModel.ExclusiveDisjunction { + return cypherModel.NewExclusiveDisjunction(convertCriteria[cypherModel.Expression](criteria...)...) 
} -func Parameter(value any) *model.Parameter { - if parameter, isParameter := value.(*model.Parameter); isParameter { +func Parameter(value any) *cypherModel.Parameter { + if parameter, isParameter := value.(*cypherModel.Parameter); isParameter { return parameter } - return &model.Parameter{ + return &cypherModel.Parameter{ Value: value, } } -func Literal(value any) *model.Literal { - return &model.Literal{ +func Literal(value any) *cypherModel.Literal { + return &cypherModel.Literal{ Value: value, Null: value == nil, } } -func KindsOf(ref graph.Criteria) *model.FunctionInvocation { +func KindsOf(ref graph.Criteria) *cypherModel.FunctionInvocation { switch typedRef := ref.(type) { - case *model.Variable: + case *cypherModel.Variable: switch typedRef.Symbol { - case NodeSymbol, RelationshipStartSymbol, RelationshipEndSymbol: - return &model.FunctionInvocation{ + case NodeSymbol, EdgeStartSymbol, EdgeEndSymbol: + return &cypherModel.FunctionInvocation{ Name: "labels", - Arguments: []model.Expression{ref}, + Arguments: []cypherModel.Expression{ref}, } - case RelationshipSymbol: - return &model.FunctionInvocation{ + case EdgeSymbol: + return &cypherModel.FunctionInvocation{ Name: "type", - Arguments: []model.Expression{ref}, + Arguments: []cypherModel.Expression{ref}, } default: - return model.WithErrors(&model.FunctionInvocation{}, fmt.Errorf("invalid variable reference for KindsOf: %s", typedRef.Symbol)) + return cypherModel.WithErrors(&cypherModel.FunctionInvocation{}, fmt.Errorf("invalid variable reference for KindsOf: %s", typedRef.Symbol)) } default: - return model.WithErrors(&model.FunctionInvocation{}, fmt.Errorf("invalid reference type for KindsOf: %T", ref)) + return cypherModel.WithErrors(&cypherModel.FunctionInvocation{}, fmt.Errorf("invalid reference type for KindsOf: %T", ref)) } } -func Limit(limit int) *model.Limit { - return &model.Limit{ +func Limit(limit int) *cypherModel.Limit { + return &cypherModel.Limit{ Value: Literal(limit), } } -func 
Offset(offset int) *model.Skip { - return &model.Skip{ +func Offset(offset int) *cypherModel.Skip { + return &cypherModel.Skip{ Value: Literal(offset), } } -func StringContains(reference graph.Criteria, value string) *model.Comparison { - return model.NewComparison(reference, model.OperatorContains, Parameter(value)) +func StringContains(reference graph.Criteria, value string) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorContains, Parameter(value)) } -func StringStartsWith(reference graph.Criteria, value string) *model.Comparison { - return model.NewComparison(reference, model.OperatorStartsWith, Parameter(value)) +func StringStartsWith(reference graph.Criteria, value string) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorStartsWith, Parameter(value)) } -func StringEndsWith(reference graph.Criteria, value string) *model.Comparison { - return model.NewComparison(reference, model.OperatorEndsWith, Parameter(value)) +func StringEndsWith(reference graph.Criteria, value string) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorEndsWith, Parameter(value)) } -func CaseInsensitiveStringContains(reference graph.Criteria, value string) *model.Comparison { - return model.NewComparison( - model.NewSimpleFunctionInvocation("toLower", convertCriteria[model.Expression](reference)...), - model.OperatorContains, +func CaseInsensitiveStringContains(reference graph.Criteria, value string) *cypherModel.Comparison { + return cypherModel.NewComparison( + cypherModel.NewSimpleFunctionInvocation("toLower", convertCriteria[cypherModel.Expression](reference)...), + cypherModel.OperatorContains, Parameter(strings.ToLower(value)), ) } -func CaseInsensitiveStringStartsWith(reference graph.Criteria, value string) *model.Comparison { - return model.NewComparison( - model.NewSimpleFunctionInvocation("toLower", convertCriteria[model.Expression](reference)...), - 
model.OperatorStartsWith, +func CaseInsensitiveStringStartsWith(reference graph.Criteria, value string) *cypherModel.Comparison { + return cypherModel.NewComparison( + cypherModel.NewSimpleFunctionInvocation("toLower", convertCriteria[cypherModel.Expression](reference)...), + cypherModel.OperatorStartsWith, Parameter(strings.ToLower(value)), ) } -func CaseInsensitiveStringEndsWith(reference graph.Criteria, value string) *model.Comparison { - return model.NewComparison( - model.NewSimpleFunctionInvocation("toLower", convertCriteria[model.Expression](reference)...), - model.OperatorEndsWith, +func CaseInsensitiveStringEndsWith(reference graph.Criteria, value string) *cypherModel.Comparison { + return cypherModel.NewComparison( + cypherModel.NewSimpleFunctionInvocation("toLower", convertCriteria[cypherModel.Expression](reference)...), + cypherModel.OperatorEndsWith, Parameter(strings.ToLower(value)), ) } -func Equals(reference graph.Criteria, value any) *model.Comparison { - return model.NewComparison(reference, model.OperatorEquals, Parameter(value)) +func Equals(reference graph.Criteria, value any) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorEquals, Parameter(value)) } -func GreaterThan(reference graph.Criteria, value any) *model.Comparison { - return model.NewComparison(reference, model.OperatorGreaterThan, Parameter(value)) +func GreaterThan(reference graph.Criteria, value any) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorGreaterThan, Parameter(value)) } -func After(reference graph.Criteria, value any) *model.Comparison { +func After(reference graph.Criteria, value any) *cypherModel.Comparison { return GreaterThan(reference, value) } -func GreaterThanOrEquals(reference graph.Criteria, value any) *model.Comparison { - return model.NewComparison(reference, model.OperatorGreaterThanOrEqualTo, Parameter(value)) +func GreaterThanOrEquals(reference graph.Criteria, value any) 
*cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorGreaterThanOrEqualTo, Parameter(value)) } -func LessThan(reference graph.Criteria, value any) *model.Comparison { - return model.NewComparison(reference, model.OperatorLessThan, Parameter(value)) +func LessThan(reference graph.Criteria, value any) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorLessThan, Parameter(value)) } -func Before(reference graph.Criteria, value time.Time) *model.Comparison { +func Before(reference graph.Criteria, value time.Time) *cypherModel.Comparison { return LessThan(reference, value) } -func LessThanOrEquals(reference graph.Criteria, value any) *model.Comparison { - return model.NewComparison(reference, model.OperatorLessThanOrEqualTo, Parameter(value)) +func LessThanOrEquals(reference graph.Criteria, value any) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorLessThanOrEqualTo, Parameter(value)) } -func Distinct(reference graph.Criteria) *model.FunctionInvocation { - return model.NewSimpleFunctionInvocation("distinct", reference) +func Exists(reference graph.Criteria) *cypherModel.Comparison { + return cypherModel.NewComparison( + reference, + cypherModel.OperatorIsNot, + cypherModel.NewLiteral(nil, true), + ) } -func Exists(reference graph.Criteria) *model.FunctionInvocation { - return model.NewSimpleFunctionInvocation("exists", reference) -} +func HasRelationships(reference *cypherModel.Variable) *cypherModel.PatternPredicate { + patternPredicate := cypherModel.NewPatternPredicate() -func HasRelationships(reference *model.Variable) graph.Criteria { - return []*model.PatternPart{ - model.NewPatternPart().AddPatternElements( - &model.NodePattern{ - Binding: reference.Symbol, - }, - &model.RelationshipPattern{ - Direction: graph.DirectionBoth, - }, - &model.NodePattern{}, - ), - } + patternPredicate.AddElement(&cypherModel.NodePattern{ + Binding: 
cypherModel.NewVariableWithSymbol(reference.Symbol), + }) + + patternPredicate.AddElement(&cypherModel.RelationshipPattern{ + Direction: graph.DirectionBoth, + }) + + patternPredicate.AddElement(&cypherModel.NodePattern{}) + + return patternPredicate } -func In(reference graph.Criteria, value any) *model.Comparison { - return model.NewComparison(reference, model.OperatorIn, Parameter(value)) +func In(reference graph.Criteria, value any) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorIn, Parameter(value)) } -func InIDs[T *model.FunctionInvocation | *model.Variable](reference T, ids ...graph.ID) *model.Comparison { +func InIDs[T *cypherModel.FunctionInvocation | *cypherModel.Variable](reference T, ids ...graph.ID) *cypherModel.Comparison { switch any(reference).(type) { - case *model.FunctionInvocation: - return model.NewComparison(reference, model.OperatorIn, Parameter(ids)) + case *cypherModel.FunctionInvocation: + return cypherModel.NewComparison(reference, cypherModel.OperatorIn, Parameter(ids)) default: - return model.NewComparison(Identity(any(reference).(*model.Variable)), model.OperatorIn, Parameter(ids)) + return cypherModel.NewComparison(Identity(any(reference).(*cypherModel.Variable)), cypherModel.OperatorIn, Parameter(ids)) } } -func Where(expression graph.Criteria) *model.Where { - return &model.Where{ - JoiningExpression: model.JoiningExpression{ - Expressions: convertCriteria[model.Expression](expression), - }, - } +func Where(expression graph.Criteria) *cypherModel.Where { + whereClause := cypherModel.NewWhere() + whereClause.AddSlice(convertCriteria[cypherModel.Expression](expression)) + + return whereClause } -func OrderBy(leaves ...graph.Criteria) *model.Order { - return &model.Order{ - Items: convertCriteria[*model.SortItem](leaves...), +func OrderBy(leaves ...graph.Criteria) *cypherModel.Order { + return &cypherModel.Order{ + Items: convertCriteria[*cypherModel.SortItem](leaves...), } } -func 
Order(reference, direction graph.Criteria) *model.SortItem { +func Order(reference, direction graph.Criteria) *cypherModel.SortItem { switch direction { - case model.SortDescending: - return &model.SortItem{ + case cypherModel.SortDescending: + return &cypherModel.SortItem{ Ascending: false, Expression: reference, } default: - return &model.SortItem{ + return &cypherModel.SortItem{ Ascending: true, Expression: reference, } } } -func Ascending() model.SortOrder { - return model.SortAscending +func Ascending() cypherModel.SortOrder { + return cypherModel.SortAscending } -func Descending() model.SortOrder { - return model.SortDescending +func Descending() cypherModel.SortOrder { + return cypherModel.SortDescending } -func Delete(leaves ...graph.Criteria) *model.UpdatingClause { - deleteClause := &model.Delete{ +func Delete(leaves ...graph.Criteria) *cypherModel.UpdatingClause { + deleteClause := &cypherModel.Delete{ Detach: true, } for _, leaf := range leaves { - switch leaf.(*model.Variable).Symbol { - case RelationshipSymbol, RelationshipStartSymbol, RelationshipEndSymbol: + switch leaf.(*cypherModel.Variable).Symbol { + case EdgeSymbol, EdgeStartSymbol, EdgeEndSymbol: deleteClause.Detach = false } deleteClause.Expressions = append(deleteClause.Expressions, leaf) } - return model.NewUpdatingClause(deleteClause) + return cypherModel.NewUpdatingClause(deleteClause) } -func NodePattern(kinds graph.Kinds, properties *model.Parameter) *model.NodePattern { - return &model.NodePattern{ - Binding: NodeSymbol, +func NodePattern(kinds graph.Kinds, properties *cypherModel.Parameter) *cypherModel.NodePattern { + return &cypherModel.NodePattern{ + Binding: cypherModel.NewVariableWithSymbol(NodeSymbol), Kinds: kinds, Properties: properties, } } -func StartNodePattern(kinds graph.Kinds, properties *model.Parameter) *model.NodePattern { - return &model.NodePattern{ - Binding: RelationshipStartSymbol, +func StartNodePattern(kinds graph.Kinds, properties *cypherModel.Parameter) 
*cypherModel.NodePattern { + return &cypherModel.NodePattern{ + Binding: cypherModel.NewVariableWithSymbol(EdgeStartSymbol), Kinds: kinds, Properties: properties, } } -func EndNodePattern(kinds graph.Kinds, properties *model.Parameter) *model.NodePattern { - return &model.NodePattern{ - Binding: RelationshipEndSymbol, +func EndNodePattern(kinds graph.Kinds, properties *cypherModel.Parameter) *cypherModel.NodePattern { + return &cypherModel.NodePattern{ + Binding: cypherModel.NewVariableWithSymbol(EdgeEndSymbol), Kinds: kinds, Properties: properties, } } -func RelationshipPattern(kind graph.Kind, properties *model.Parameter, direction graph.Direction) *model.RelationshipPattern { - return &model.RelationshipPattern{ - Binding: RelationshipSymbol, +func RelationshipPattern(kind graph.Kind, properties *cypherModel.Parameter, direction graph.Direction) *cypherModel.RelationshipPattern { + return &cypherModel.RelationshipPattern{ + Binding: cypherModel.NewVariableWithSymbol(EdgeSymbol), Kinds: graph.Kinds{kind}, Properties: properties, Direction: direction, } } -func Create(elements ...graph.Criteria) *model.UpdatingClause { +func Create(elements ...graph.Criteria) *cypherModel.UpdatingClause { var ( - pattern = &model.PatternPart{} - createClause = &model.Create{ + pattern = &cypherModel.PatternPart{} + createClause = &cypherModel.Create{ // Note: Unique is Neo4j specific and will not be supported here. Use of constraints for // uniqueness is expected instead. 
Unique: false, - Pattern: []*model.PatternPart{pattern}, + Pattern: []*cypherModel.PatternPart{pattern}, } ) for _, element := range elements { switch typedElement := element.(type) { - case *model.Variable: + case *cypherModel.Variable: switch typedElement.Symbol { - case NodeSymbol, RelationshipStartSymbol, RelationshipEndSymbol: - pattern.AddPatternElements(&model.NodePattern{ - Binding: typedElement.Symbol, + case NodeSymbol, EdgeStartSymbol, EdgeEndSymbol: + pattern.AddPatternElements(&cypherModel.NodePattern{ + Binding: cypherModel.NewVariableWithSymbol(typedElement.Symbol), }) default: createClause.AddError(fmt.Errorf("invalid variable reference create: %s", typedElement.Symbol)) } - case *model.NodePattern: + case *cypherModel.NodePattern: pattern.AddPatternElements(typedElement) - case *model.RelationshipPattern: + case *cypherModel.RelationshipPattern: pattern.AddPatternElements(typedElement) default: @@ -483,50 +492,57 @@ func Create(elements ...graph.Criteria) *model.UpdatingClause { } } - return model.NewUpdatingClause(createClause) + return cypherModel.NewUpdatingClause(createClause) +} + +func ReturningDistinct(elements ...graph.Criteria) *cypherModel.Return { + returnCriteria := Returning(elements...) 
+ returnCriteria.Projection.Distinct = true + + return returnCriteria } -func Returning(elements ...graph.Criteria) *model.Return { - projection := &model.Projection{} +func Returning(elements ...graph.Criteria) *cypherModel.Return { + projection := &cypherModel.Projection{} for _, element := range elements { switch typedElement := element.(type) { - case *model.Order: + case *cypherModel.Order: projection.Order = typedElement - case *model.Limit: + case *cypherModel.Limit: projection.Limit = typedElement - case *model.Skip: + case *cypherModel.Skip: projection.Skip = typedElement default: - projection.Items = append(projection.Items, &model.ProjectionItem{ + projection.Items = append(projection.Items, &cypherModel.ProjectionItem{ Expression: element, }) } } - return &model.Return{ + return &cypherModel.Return{ Projection: projection, } } -func Not(expression graph.Criteria) *model.Negation { - return &model.Negation{ +func Not(expression graph.Criteria) *cypherModel.Negation { + return &cypherModel.Negation{ Expression: expression, } } -func IsNull(reference graph.Criteria) *model.Comparison { - return model.NewComparison(reference, model.OperatorIs, Literal(nil)) +func IsNull(reference graph.Criteria) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorIs, Literal(nil)) } -func IsNotNull(reference graph.Criteria) *model.Comparison { - return model.NewComparison(reference, model.OperatorIsNot, Literal(nil)) +func IsNotNull(reference graph.Criteria) *cypherModel.Comparison { + return cypherModel.NewComparison(reference, cypherModel.OperatorIsNot, Literal(nil)) } -func GetFirstReadingClause(query *model.RegularQuery) *model.ReadingClause { +func GetFirstReadingClause(query *cypherModel.RegularQuery) *cypherModel.ReadingClause { if query.SingleQuery != nil && query.SingleQuery.SinglePartQuery != nil { readingClauses := query.SingleQuery.SinglePartQuery.ReadingClauses @@ -538,11 +554,11 @@ func GetFirstReadingClause(query 
*model.RegularQuery) *model.ReadingClause { return nil } -func SinglePartQuery(expressions ...graph.Criteria) *model.RegularQuery { +func SinglePartQuery(expressions ...graph.Criteria) *cypherModel.RegularQuery { var ( - singlePartQuery = &model.SinglePartQuery{} - query = &model.RegularQuery{ - SingleQuery: &model.SingleQuery{ + singlePartQuery = &cypherModel.SinglePartQuery{} + query = &cypherModel.RegularQuery{ + SingleQuery: &cypherModel.SingleQuery{ SinglePartQuery: singlePartQuery, }, } @@ -550,40 +566,40 @@ func SinglePartQuery(expressions ...graph.Criteria) *model.RegularQuery { for _, expression := range expressions { switch typedExpression := expression.(type) { - case *model.Where: + case *cypherModel.Where: if firstReadingClause := GetFirstReadingClause(query); firstReadingClause != nil { firstReadingClause.Match.Where = typedExpression } else { - singlePartQuery.AddReadingClause(&model.ReadingClause{ - Match: &model.Match{ + singlePartQuery.AddReadingClause(&cypherModel.ReadingClause{ + Match: &cypherModel.Match{ Where: typedExpression, }, Unwind: nil, }) } - case *model.Return: + case *cypherModel.Return: singlePartQuery.Return = typedExpression - case *model.Limit: + case *cypherModel.Limit: if singlePartQuery.Return != nil { singlePartQuery.Return.Projection.Limit = typedExpression } - case *model.Skip: + case *cypherModel.Skip: if singlePartQuery.Return != nil { singlePartQuery.Return.Projection.Skip = typedExpression } - case *model.Order: + case *cypherModel.Order: if singlePartQuery.Return != nil { singlePartQuery.Return.Projection.Order = typedExpression } - case *model.UpdatingClause: + case *cypherModel.UpdatingClause: singlePartQuery.AddUpdatingClause(typedExpression) - case []*model.UpdatingClause: + case []*cypherModel.UpdatingClause: for _, updatingClause := range typedExpression { singlePartQuery.AddUpdatingClause(updatingClause) } @@ -595,3 +611,11 @@ func SinglePartQuery(expressions ...graph.Criteria) *model.RegularQuery { return query 
} + +func EmptySinglePartQuery() *cypherModel.RegularQuery { + return &cypherModel.RegularQuery{ + SingleQuery: &cypherModel.SingleQuery{ + SinglePartQuery: &cypherModel.SinglePartQuery{}, + }, + } +} diff --git a/packages/go/dawgs/query/neo4j/neo4j.go b/packages/go/dawgs/query/neo4j/neo4j.go index e16a707381..af2bce8a8d 100644 --- a/packages/go/dawgs/query/neo4j/neo4j.go +++ b/packages/go/dawgs/query/neo4j/neo4j.go @@ -20,8 +20,7 @@ import ( "bytes" "errors" "fmt" - - "github.com/specterops/bloodhound/cypher/frontend" + "github.com/specterops/bloodhound/cypher/backend/cypher" "github.com/specterops/bloodhound/cypher/model" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/query" @@ -35,6 +34,7 @@ type QueryBuilder struct { Parameters map[string]any query *model.RegularQuery + order *model.Order relationshipPatternKinds graph.Kinds prepared bool } @@ -55,10 +55,10 @@ func NewEmptyQueryBuilder() *QueryBuilder { } } -func (s *QueryBuilder) liftRelationshipKindMatchers() func(parent, element any) error { +func (s *QueryBuilder) liftRelationshipKindMatchers() model.Visitor { firstReadingClause := query.GetFirstReadingClause(s.query) - return func(parent, element any) error { + return model.NewVisitor(func(stack *model.WalkStack, element model.Expression) error { if firstReadingClause == nil { return nil } @@ -79,7 +79,7 @@ func (s *QueryBuilder) liftRelationshipKindMatchers() func(parent, element any) switch variable := typedExpression.Reference.(type) { case *model.Variable: switch variable.Symbol { - case query.RelationshipSymbol: + case query.EdgeSymbol: firstRelationshipPattern.Kinds = append(firstRelationshipPattern.Kinds, typedExpression.Kinds...) 
removeList = append(removeList, expression) } @@ -93,13 +93,13 @@ func (s *QueryBuilder) liftRelationshipKindMatchers() func(parent, element any) } return nil - } + }, nil) } func (s *QueryBuilder) rewriteParameters() error { parameterRewriter := query.NewParameterRewriter() - if err := model.Walk(s.query, parameterRewriter.Visit, nil); err != nil { + if err := model.Walk(s.query, model.NewVisitor(parameterRewriter.Visit, nil)); err != nil { return err } @@ -119,7 +119,7 @@ func (s *QueryBuilder) Apply(criteria graph.Criteria) { query.GetFirstReadingClause(s.query).Match.Where = model.Copy(typedCriteria) case *model.Return: - s.query.SingleQuery.SinglePartQuery.Return = typedCriteria + s.query.SingleQuery.SinglePartQuery.Return = model.Copy(typedCriteria) case *model.Limit: if s.query.SingleQuery.SinglePartQuery.Return != nil { @@ -132,9 +132,7 @@ func (s *QueryBuilder) Apply(criteria graph.Criteria) { } case *model.Order: - if s.query.SingleQuery.SinglePartQuery.Return != nil { - s.query.SingleQuery.SinglePartQuery.Return.Projection.Order = model.Copy(typedCriteria) - } + s.order = model.Copy(typedCriteria) case []*model.UpdatingClause: for _, updatingClause := range typedCriteria { @@ -165,58 +163,72 @@ func (s *QueryBuilder) prepareMatch() error { isRelationshipQuery = false - bindWalk = func(parent, element any) error { + bindWalk = model.NewVisitor(func(stack *model.WalkStack, element model.Expression) error { switch typedElement := element.(type) { case *model.Variable: switch typedElement.Symbol { case query.NodeSymbol: singleNodeBound = true - case query.RelationshipStartSymbol: + case query.EdgeStartSymbol: startNodeBound = true isRelationshipQuery = true - case query.RelationshipEndSymbol: + case query.EdgeEndSymbol: endNodeBound = true isRelationshipQuery = true - case query.RelationshipSymbol: + case query.EdgeSymbol: relationshipBound = true isRelationshipQuery = true } } return nil - } + }, nil) ) // Zip through updating clauses first - for _, 
updateClause := range s.query.SingleQuery.SinglePartQuery.UpdatingClauses { - switch typedClause := updateClause.Clause.(type) { + for _, updatingClause := range s.query.SingleQuery.SinglePartQuery.UpdatingClauses { + typedUpdatingClause, typeOK := updatingClause.(*model.UpdatingClause) + + if !typeOK { + return fmt.Errorf("unexpected updating clause type %T", typedUpdatingClause) + } + + switch typedClause := typedUpdatingClause.Clause.(type) { case *model.Create: - if err := model.Walk(typedClause, func(parent, element any) error { + if err := model.Walk(typedClause, model.NewVisitor(func(stack *model.WalkStack, element model.Expression) error { switch typedElement := element.(type) { case *model.NodePattern: - switch typedElement.Binding { - case query.NodeSymbol: - creatingSingleNode = true - - case query.RelationshipStartSymbol: - creatingStartNode = true - - case query.RelationshipEndSymbol: - creatingEndNode = true + if typedBinding, isVariable := typedElement.Binding.(*model.Variable); !isVariable { + return fmt.Errorf("expected variable but got %T", typedElement.Binding) + } else { + switch typedBinding.Symbol { + case query.NodeSymbol: + creatingSingleNode = true + + case query.EdgeStartSymbol: + creatingStartNode = true + + case query.EdgeEndSymbol: + creatingEndNode = true + } } case *model.RelationshipPattern: - switch typedElement.Binding { - case query.RelationshipSymbol: - creatingRelationship = true + if typedBinding, isVariable := typedElement.Binding.(*model.Variable); !isVariable { + return fmt.Errorf("expected variable but got %T", typedElement.Binding) + } else { + switch typedBinding.Symbol { + case query.EdgeSymbol: + creatingRelationship = true + } } } return nil - }, nil); err != nil { + }, nil)); err != nil { return err } @@ -235,7 +247,16 @@ func (s *QueryBuilder) prepareMatch() error { } // Is there a return clause - if s.query.SingleQuery.SinglePartQuery.Return != nil { + if spqReturn := s.query.SingleQuery.SinglePartQuery.Return; 
spqReturn != nil && spqReturn.Projection != nil { + // Did we have an order specified? + if s.order != nil { + if spqReturn.Projection.Order != nil { + return fmt.Errorf("order specified twice") + } + + s.query.SingleQuery.SinglePartQuery.Return.Projection.Order = s.order + } + if err := model.Walk(s.query.SingleQuery.SinglePartQuery.Return, bindWalk, nil); err != nil { return err } @@ -248,13 +269,13 @@ func (s *QueryBuilder) prepareMatch() error { if singleNodeBound && !creatingSingleNode { patternPart.AddPatternElements(&model.NodePattern{ - Binding: query.NodeSymbol, + Binding: model.NewVariableWithSymbol(query.NodeSymbol), }) } if startNodeBound { patternPart.AddPatternElements(&model.NodePattern{ - Binding: query.RelationshipStartSymbol, + Binding: model.NewVariableWithSymbol(query.EdgeStartSymbol), }) } @@ -266,7 +287,7 @@ func (s *QueryBuilder) prepareMatch() error { if !creatingRelationship { if relationshipBound { patternPart.AddPatternElements(&model.RelationshipPattern{ - Binding: query.RelationshipSymbol, + Binding: model.NewVariableWithSymbol(query.EdgeSymbol), Direction: graph.DirectionOutbound, }) } else { @@ -283,7 +304,7 @@ func (s *QueryBuilder) prepareMatch() error { if endNodeBound { patternPart.AddPatternElements(&model.NodePattern{ - Binding: query.RelationshipEndSymbol, + Binding: model.NewVariableWithSymbol(query.EdgeEndSymbol), }) } @@ -305,7 +326,7 @@ func (s *QueryBuilder) prepareMatch() error { func (s *QueryBuilder) compilationErrors() error { var modelErrors []error - model.Walk(s.query, func(parent, element any) error { + model.Walk(s.query, model.NewVisitor(func(stack *model.WalkStack, element model.Expression) error { if errorNode, typeOK := element.(model.Fallible); typeOK { if len(errorNode.Errors()) > 0 { modelErrors = append(modelErrors, errorNode.Errors()...) @@ -313,7 +334,7 @@ func (s *QueryBuilder) compilationErrors() error { } return nil - }, nil) + }, nil)) return errors.Join(modelErrors...) 
} @@ -341,11 +362,15 @@ func (s *QueryBuilder) Prepare() error { return err } - if err := model.Walk(s.query, StringNegationRewriter, nil); err != nil { + if err := model.Walk(s.query, model.NewVisitor(StringNegationRewriter, nil)); err != nil { + return err + } + + if err := model.Walk(s.query, s.liftRelationshipKindMatchers()); err != nil { return err } - return model.Walk(s.query, s.liftRelationshipKindMatchers(), RemoveEmptyExpressionLists) + return model.Walk(s.query, model.NewVisitor(nil, RemoveEmptyExpressionLists)) } func (s *QueryBuilder) PrepareAllShortestPaths() error { @@ -363,7 +388,7 @@ func (s *QueryBuilder) PrepareAllShortestPaths() error { patternPart := firstReadingClause.Match.Pattern[0] // Bind the path - patternPart.Binding = query.PathSymbol + patternPart.Binding = model.NewVariableWithSymbol(query.PathSymbol) // Set the pattern to search for all shortest paths patternPart.AllShortestPathsPattern = true @@ -382,7 +407,7 @@ func (s *QueryBuilder) PrepareAllShortestPaths() error { func (s *QueryBuilder) Render() (string, error) { buffer := &bytes.Buffer{} - if err := frontend.NewCypherEmitter(false).Write(s.query, buffer); err != nil { + if err := cypher.NewCypherEmitter(false).Write(s.query, buffer); err != nil { return "", err } else { return buffer.String(), nil diff --git a/packages/go/dawgs/query/neo4j/neo4j_test.go b/packages/go/dawgs/query/neo4j/neo4j_test.go index 2b81c57802..5cce0744c8 100644 --- a/packages/go/dawgs/query/neo4j/neo4j_test.go +++ b/packages/go/dawgs/query/neo4j/neo4j_test.go @@ -140,9 +140,9 @@ func TestQueryBuilder_RenderShortestPaths(t *testing.T) { query.Returning( query.Path(), ), - ), "match p = allShortestPaths((s)-[*]->(e)) where s.objectid = $0 and (s:A or s:B) and e.objectid = $1 and e:B return p", map[string]any{ - "0": "12345", - "1": "56789", + ), "match p = allShortestPaths((s)-[*]->(e)) where s.objectid = $p0 and (s:A or s:B) and e.objectid = $p1 and e:B return p", map[string]any{ + "p0": "12345", + "p1": 
"56789", })) t.Run("Shortest Paths with Bound Relationship", assertQueryShortestPathResult(query.SinglePartQuery( @@ -159,9 +159,9 @@ func TestQueryBuilder_RenderShortestPaths(t *testing.T) { query.Returning( query.Path(), ), - ), "match p = allShortestPaths((s)-[r:R1|R2*]->(e)) where s.objectid = $0 and (s:A or s:B) and e.objectid = $1 and e:B return p", map[string]any{ - "0": "12345", - "1": "56789", + ), "match p = allShortestPaths((s)-[r:R1|R2*]->(e)) where s.objectid = $p0 and (s:A or s:B) and e.objectid = $p1 and e:B return p", map[string]any{ + "p0": "12345", + "p1": "56789", })) } @@ -178,11 +178,11 @@ func TestQueryBuilder_Render(t *testing.T) { query.Limit(10), query.Offset(20), - ), "match (n) where id(n) in $0 return count(n) skip 20 limit 10", map[string]any{ - "0": []graph.ID{1, 2, 3, 4}, + ), "match (n) where id(n) in $p0 return count(n) skip 20 limit 10", map[string]any{ + "p0": []graph.ID{1, 2, 3, 4}, })) - t.Run("Node Property", assertQueryResult(query.SinglePartQuery( + t.Run("Node Item", assertQueryResult(query.SinglePartQuery( query.Where( query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), ), @@ -190,13 +190,13 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Count(query.Node()), ), - ), "match (n) where n.prop in $0 return count(n)")) + ), "match (n) where n.prop in $p0 return count(n)")) // TODO: Revisit parameter reuse // //reusedLiteral := query3.Literal([]int{1, 2, 3, 4}) // - //t.Run("Node Property with Reused Literal", assertQueryResult(query3.Query( + //t.Run("Node Item with Reused Literal", assertQueryResult(query3.Query( // query3.Where( // query3.And( // query3.In(query3.NodeProperty("prop"), reusedLiteral), @@ -207,27 +207,27 @@ func TestQueryBuilder_Render(t *testing.T) { // query3.Returning( // query3.Count(query3.Node()), // ), - //), "match (n) where n.prop in $0 and n.other_prop in $0 return count(n)")) + //), "match (n) where n.prop in $p0 and n.other_prop in $p0 return count(n)")) - t.Run("Distinct 
Property", assertQueryResult(query.SinglePartQuery( + t.Run("Distinct Item", assertQueryResult(query.SinglePartQuery( query.Where( query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), ), - query.Returning( - query.Distinct(query.NodeProperty("prop")), + query.ReturningDistinct( + query.NodeProperty("prop"), ), - ), "match (n) where n.prop in $0 return distinct(n.prop)")) + ), "match (n) where n.prop in $p0 return distinct n.prop")) - t.Run("Count Distinct Property", assertQueryResult(query.SinglePartQuery( + t.Run("Count Distinct Item", assertQueryResult(query.SinglePartQuery( query.Where( query.In(query.NodeProperty("prop"), []int{1, 2, 3, 4}), ), query.Returning( - query.Count(query.Distinct(query.NodeProperty("prop"))), + query.CountDistinct(query.NodeProperty("prop")), ), - ), "match (n) where n.prop in $0 return count(distinct(n.prop))")) + ), "match (n) where n.prop in $p0 return count(distinct n.prop)")) t.Run("Set Node Labels", assertQueryResult(query.SinglePartQuery( query.Where( @@ -242,7 +242,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Count(query.Node()), ), - ), "match (n) where n.prop in $0 set n:Domain set n:User return count(n)")) + ), "match (n) where n.prop in $p0 set n:Domain set n:User return count(n)")) t.Run("Remove Node Labels", assertQueryResult(query.SinglePartQuery( query.Where( @@ -257,7 +257,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Count(query.Node()), ), - ), "match (n) where n.prop in $0 remove n:Domain remove n:User return count(n)")) + ), "match (n) where n.prop in $p0 remove n:Domain remove n:User return count(n)")) t.Run("Multiple Node ID References", assertQueryResult(query.SinglePartQuery( query.Where( @@ -274,7 +274,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Limit(10), query.Offset(20), - ), "match (n) where n.name = $0 and id(n) in $1 return id(n), n.value skip 20 limit 10")) + ), "match (n) where n.name = $p0 and id(n) in $p1 return id(n), n.value 
skip 20 limit 10")) // Create node t.Run("Create Node", assertQueryResult(query.SinglePartQuery( @@ -291,9 +291,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Identity(query.Node()), ), ), - "create (n:Domain:Computer $0) return id(n)", + "create (n:Domain:Computer $p0) return id(n)", map[string]any{ - "0": map[string]any{ + "p0": map[string]any{ "prop1": 1234, }, }, @@ -321,7 +321,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Limit(10), query.Offset(20), - ), "match (n) where n.name = $0 and id(n) in $1 remove n.other remove n.other2 return id(n), n.value skip 20 limit 10")) + ), "match (n) where n.name = $p0 and id(n) in $p1 remove n.other remove n.other2 return id(n), n.value skip 20 limit 10")) properties := graph.NewProperties() properties.Set("test_1", "value_1") @@ -337,19 +337,19 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), []QueryOutputAssertion{ { - Query: "match (n) where n.objectid = $0 set n.test_1 = $1, n.test_2 = $2", + Query: "match (n) where n.objectid = $p0 set n.test_1 = $p1, n.test_2 = $p2", Parameters: map[string]any{ - "0": "12345", - "1": "value_1", - "2": "value_2", + "p0": "12345", + "p1": "value_1", + "p2": "value_2", }, }, { - Query: "match (n) where n.objectid = $0 set n.test_2 = $1, n.test_1 = $2", + Query: "match (n) where n.objectid = $p0 set n.test_2 = $p1, n.test_1 = $p2", Parameters: map[string]any{ - "0": "12345", - "1": "value_2", - "2": "value_1", + "p0": "12345", + "p1": "value_2", + "p2": "value_1", }, }, })) @@ -367,15 +367,15 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), []QueryOutputAssertion{ { - Query: "match (n) where n.objectid = $0 remove n.test_2, n.test_1", + Query: "match (n) where n.objectid = $p0 remove n.test_2, n.test_1", Parameters: map[string]any{ - "0": "12345", + "p0": "12345", }, }, { - Query: "match (n) where n.objectid = $0 remove n.test_1, n.test_2", + Query: "match (n) where n.objectid = $p0 remove n.test_1, n.test_2", Parameters: map[string]any{ - "0": "12345", + 
"p0": "12345", }, }, })) @@ -399,7 +399,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Limit(10), query.Offset(20), - ), "match (n) where n.name = $0 and id(n) in $1 set n.other = $2 return id(n), n.value skip 20 limit 10")) + ), "match (n) where n.name = $p0 and id(n) in $p1 set n.other = $p2 return id(n), n.value skip 20 limit 10")) updatedNode := graph.NewNode(graph.ID(1), graph.NewProperties(), User, Domain, Computer) updatedNode.Properties.Set("test_1", "value_1") @@ -428,7 +428,7 @@ func TestQueryBuilder_Render(t *testing.T) { return updateStatements }), - ), "match (n) where id(n) = $0 set n:User:Domain:Computer set n.test_1 = $1 remove n.test_2")) + ), "match (n) where id(n) = $p0 set n:User:Domain:Computer set n.test_1 = $p1 remove n.test_2")) t.Run("Node has Relationships", assertQueryResult(query.SinglePartQuery( query.Where( @@ -440,7 +440,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), "match (n) where (n)<-[]->() return n")) - t.Run("Node has Relationships Order by Node Property", assertQueryResult(query.SinglePartQuery( + t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( query.Where( query.HasRelationships(query.Node()), ), @@ -454,7 +454,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), "match (n) where (n)<-[]->() return n order by n.value asc")) - t.Run("Node has Relationships Order by Node Property", assertQueryResult(query.SinglePartQuery( + t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( query.Where( query.HasRelationships(query.Node()), ), @@ -469,7 +469,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc")) - t.Run("Node has Relationships Order by Node Property with Limit and Offset", assertQueryResult(query.SinglePartQuery( + t.Run("Node has Relationships Order by Node Item with Limit and Offset", assertQueryResult(query.SinglePartQuery( query.Where( 
query.HasRelationships(query.Node()), ), @@ -508,7 +508,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.lastseen < $0 and id(n) in $1 return n")) + ), "match (n) where n.lastseen < $p0 and id(n) in $p1 return n")) t.Run("Node Datetime Before or Equal to", assertQueryResult(query.SinglePartQuery( query.Where( @@ -521,7 +521,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.lastseen <= $0 and id(n) in $1 return n")) + ), "match (n) where n.lastseen <= $p0 and id(n) in $p1 return n")) t.Run("Node Datetime After", assertQueryResult(query.SinglePartQuery( query.Where( @@ -534,7 +534,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.lastseen > $0 and id(n) in $1 return n")) + ), "match (n) where n.lastseen > $p0 and id(n) in $p1 return n")) t.Run("Node Datetime After or Equal to", assertQueryResult(query.SinglePartQuery( query.Where( @@ -547,7 +547,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.lastseen >= $0 and id(n) in $1 return n")) + ), "match (n) where n.lastseen >= $p0 and id(n) in $p1 return n")) t.Run("Node PropertyExists", assertQueryResult(query.SinglePartQuery( query.Where( @@ -560,9 +560,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where exists(n.lastseen) and id(n) in $0 return n")) + ), "match (n) where n.lastseen is not null and id(n) in $p0 return n")) - t.Run("Return Node Kinds", assertQueryResult(query.SinglePartQuery( + t.Run("Select Node Kinds", assertQueryResult(query.SinglePartQuery( query.Where( query.And( query.Kind(query.Node(), Domain), @@ -574,7 +574,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), "match (n) where n:Domain return labels(n)")) - t.Run("Return Node ID and Kinds", assertQueryResult(query.SinglePartQuery( + t.Run("Select Node ID and 
Kinds", assertQueryResult(query.SinglePartQuery( query.Where( query.And( query.Kind(query.Node(), Domain), @@ -611,7 +611,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), "match (n) where (n:Domain or n:User or n:Group) return n")) - t.Run("Node String Property Contains", assertQueryResult(query.SinglePartQuery( + t.Run("Node String Item Contains", assertQueryResult(query.SinglePartQuery( query.Where( query.StringContains(query.NodeProperty("tags"), "tag_1"), ), @@ -619,9 +619,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.tags contains $0 return n")) + ), "match (n) where n.tags contains $p0 return n")) - t.Run("Node String Property Starts With", assertQueryResult(query.SinglePartQuery( + t.Run("Node String Item Starts With", assertQueryResult(query.SinglePartQuery( query.Where( query.StringStartsWith(query.NodeProperty("tags"), "tag_1"), ), @@ -629,9 +629,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.tags starts with $0 return n")) + ), "match (n) where n.tags starts with $p0 return n")) - t.Run("Node String Property Ends With", assertQueryResult(query.SinglePartQuery( + t.Run("Node String Item Ends With", assertQueryResult(query.SinglePartQuery( query.Where( query.StringEndsWith(query.NodeProperty("tags"), "tag_1"), ), @@ -639,9 +639,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where n.tags ends with $0 return n")) + ), "match (n) where n.tags ends with $p0 return n")) - t.Run("Node String Property Case Insensitive Contains", assertQueryResult(query.SinglePartQuery( + t.Run("Node String Item Case Insensitive Contains", assertQueryResult(query.SinglePartQuery( query.Where( query.CaseInsensitiveStringContains(query.NodeProperty("tags"), "tag_1"), ), @@ -649,9 +649,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where 
toLower(n.tags) contains $0 return n")) + ), "match (n) where toLower(n.tags) contains $p0 return n")) - t.Run("Node String Property Case Insensitive Starts With", assertQueryResult(query.SinglePartQuery( + t.Run("Node String Item Case Insensitive Starts With", assertQueryResult(query.SinglePartQuery( query.Where( query.CaseInsensitiveStringStartsWith(query.NodeProperty("tags"), "tag_1"), ), @@ -659,9 +659,9 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where toLower(n.tags) starts with $0 return n")) + ), "match (n) where toLower(n.tags) starts with $p0 return n")) - t.Run("Node String Property Case Insensitive Ends With", assertQueryResult(query.SinglePartQuery( + t.Run("Node String Item Case Insensitive Ends With", assertQueryResult(query.SinglePartQuery( query.Where( query.CaseInsensitiveStringEndsWith(query.NodeProperty("tags"), "tag_1"), ), @@ -669,7 +669,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where toLower(n.tags) ends with $0 return n")) + ), "match (n) where toLower(n.tags) ends with $p0 return n")) t.Run("Node Delete", assertQueryResult(query.SinglePartQuery( query.Where( @@ -679,7 +679,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Delete( query.Node(), ), - ), "match (n) where n in $0 detach delete n")) + ), "match (n) where n in $p0 detach delete n")) // Relationship Queries t.Run("Empty Relationship Query", assertQueryResult(query.SinglePartQuery( @@ -732,7 +732,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), ), "match (s)-[r]->(e) return id(r), type(r), labels(s), labels(e)")) - t.Run("Relationship Property and ID References", assertQueryResult(query.SinglePartQuery( + t.Run("Relationship Item and ID References", assertQueryResult(query.SinglePartQuery( query.Where( query.And( query.Equals(query.RelationshipProperty("name"), "name"), @@ -746,9 +746,9 @@ func TestQueryBuilder_Render(t *testing.T) { ), query.Offset(20), - ), 
"match ()-[r]->() where r.name = $0 and id(r) in $1 return id(r), r.value skip 20")) + ), "match ()-[r]->() where r.name = $p0 and id(r) in $p1 return id(r), r.value skip 20")) - t.Run("Relationship Return Start References", assertQueryResult(query.SinglePartQuery( + t.Run("Relationship Select Start References", assertQueryResult(query.SinglePartQuery( query.Where( query.And( query.Equals(query.RelationshipProperty("name"), "name"), @@ -762,7 +762,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), query.Offset(20), - ), "match (s)-[r]->() where r.name = $0 and id(r) in $1 return id(s), r.value skip 20")) + ), "match (s)-[r]->() where r.name = $p0 and id(r) in $p1 return id(s), r.value skip 20")) t.Run("Relationship Start Node ID Reference", assertQueryResult(query.SinglePartQuery( query.Where( @@ -779,7 +779,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), query.Offset(20), - ), "match (s)-[r]->() where id(s) = $0 and r.name = $1 and id(r) in $2 return id(r), r.value skip 20")) + ), "match (s)-[r]->() where id(s) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value skip 20")) t.Run("Relationship End Node ID Reference", assertQueryResult(query.SinglePartQuery( query.Where( @@ -796,7 +796,7 @@ func TestQueryBuilder_Render(t *testing.T) { ), query.Offset(20), - ), "match ()-[r]->(e) where id(e) = $0 and r.name = $1 and id(r) in $2 return id(r), r.value skip 20")) + ), "match ()-[r]->(e) where id(e) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value skip 20")) t.Run("Relationship Start and End Node ID References", assertQueryResult(query.SinglePartQuery( query.Where( @@ -812,7 +812,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.RelationshipID(), query.Property(query.Relationship(), "value"), ), - ), "match (s)-[r]->(e) where id(s) = $0 and id(e) = $1 and r.name = $2 and id(r) in $3 return id(r), r.value")) + ), "match (s)-[r]->(e) where id(s) = $p0 and id(e) = $p1 and r.name = $p2 and id(r) in $p3 return id(r), r.value")) 
t.Run("Relationship Kind Match without Joining Expression", assertQueryResult(query.SinglePartQuery( query.Where( @@ -839,7 +839,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.RelationshipID(), query.Property(query.Relationship(), "value"), ), - ), "match (s)-[r:HasSession]->() where id(s) = $0 and r.name = $1 and id(r) in $2 return id(r), r.value")) + ), "match (s)-[r:HasSession]->() where id(s) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value")) updatedRelationship := graph.NewRelationship(graph.ID(1), graph.ID(2), graph.ID(3), graph.NewProperties(), HasSession) updatedRelationship.Properties.Set("test_1", "value_1") @@ -866,7 +866,7 @@ func TestQueryBuilder_Render(t *testing.T) { return updateStatements }), - ), "match ()-[r]->() where id(r) = $0 set r.test_1 = $1 remove r.test_2")) + ), "match ()-[r]->() where id(r) = $p0 set r.test_1 = $p1 remove r.test_2")) t.Run("Relationship Kind Match in", assertQueryResult(query.SinglePartQuery( query.Where( @@ -882,7 +882,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.RelationshipID(), query.Property(query.Relationship(), "value"), ), - ), "match (s)-[r:HasSession|GenericWrite]->() where id(s) = $0 and r.name = $1 and id(r) in $2 return id(r), r.value")) + ), "match (s)-[r:HasSession|GenericWrite]->() where id(s) = $p0 and r.name = $p1 and id(r) in $p2 return id(r), r.value")) t.Run("Relationship Kind Match in and Start Node Kind Match in", assertQueryResult(query.SinglePartQuery( query.Where( @@ -898,7 +898,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.RelationshipID(), query.Property(query.Relationship(), "value"), ), - ), "match (s)-[r:HasSession|GenericWrite]->() where (s:User or s:Computer) and r.name = $0 and id(r) in $1 return id(r), r.value")) + ), "match (s)-[r:HasSession|GenericWrite]->() where (s:User or s:Computer) and r.name = $p0 and id(r) in $p1 return id(r), r.value")) t.Run("Relationship Kind Match in and Delete Start Node and Relationship", 
assertQueryResult(query.SinglePartQuery( query.Where( @@ -957,15 +957,15 @@ func TestQueryBuilder_Render(t *testing.T) { query.Identity(query.Relationship()), ), ), - "create (s:Computer $0)-[r:HasSession $1]->(e:User $2) return id(r)", + "create (s:Computer $p0)-[r:HasSession $p1]->(e:User $p2) return id(r)", map[string]any{ - "0": map[string]any{ + "p0": map[string]any{ "prop1": 1234, }, - "1": map[string]any{ + "p1": map[string]any{ "prop1": 1234, }, - "2": map[string]any{ + "p2": map[string]any{ "prop1": 1234, }, }, @@ -995,11 +995,11 @@ func TestQueryBuilder_Render(t *testing.T) { query.Identity(query.Relationship()), ), ), - "match (s), (e) where id(s) = $0 and id(e) = $1 create (s)-[r:HasSession $2]->(e) return id(r)", + "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:HasSession $p2]->(e) return id(r)", map[string]any{ - "0": 1, - "1": 2, - "2": map[string]any{ + "p0": 1, + "p1": 2, + "p2": map[string]any{ "prop1": 1234, }, }, @@ -1015,7 +1015,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Count(query.Node()), ), - ), "match (n) where (not (n.system_tags contains $0) or n.system_tags is null) return count(n)")) + ), "match (n) where (not (n.system_tags contains $p0) or n.system_tags is null) return count(n)")) t.Run("Is Not Null", assertQueryResult(query.SinglePartQuery( query.Where( diff --git a/packages/go/dawgs/query/neo4j/rewrite.go b/packages/go/dawgs/query/neo4j/rewrite.go index 1f25462b94..f770e4416d 100644 --- a/packages/go/dawgs/query/neo4j/rewrite.go +++ b/packages/go/dawgs/query/neo4j/rewrite.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package neo4j @@ -23,7 +23,7 @@ import ( "github.com/specterops/bloodhound/dawgs/query" ) -func RemoveEmptyExpressionLists(parent, element any) error { +func RemoveEmptyExpressionLists(stack *model.WalkStack, element model.Expression) error { var ( shouldRemove = false shouldReplace = false @@ -52,12 +52,12 @@ func RemoveEmptyExpressionLists(parent, element any) error { } if shouldRemove { - switch typedParent := parent.(type) { + switch typedParent := stack.Trunk().(type) { case model.ExpressionList: typedParent.Remove(element) } } else if shouldReplace { - switch typedParent := parent.(type) { + switch typedParent := stack.Trunk().(type) { case model.ExpressionList: typedParent.Replace(typedParent.IndexOf(element), replacementExpression) } @@ -66,7 +66,7 @@ func RemoveEmptyExpressionLists(parent, element any) error { return nil } -func StringNegationRewriter(parent, element any) error { +func StringNegationRewriter(stack *model.WalkStack, element model.Expression) error { var rewritten any switch negation := element.(type) { @@ -94,7 +94,7 @@ func StringNegationRewriter(parent, element any) error { // If we rewrote this element, replace it if rewritten != nil { - switch typedParent := parent.(type) { + switch typedParent := stack.Trunk().(type) { case model.ExpressionList: for idx, expression := range typedParent.GetAll() { if expression == element { @@ -104,7 +104,7 @@ func StringNegationRewriter(parent, element any) error { } default: - return fmt.Errorf("unable to replace rewritten string 
negation operation for parent type %T", parent) + return fmt.Errorf("unable to replace rewritten string negation operation for parent type %T", stack.Trunk()) } } diff --git a/packages/go/dawgs/query/rewrite.go b/packages/go/dawgs/query/rewrite.go index 29e4bd150f..3e6bcaf806 100644 --- a/packages/go/dawgs/query/rewrite.go +++ b/packages/go/dawgs/query/rewrite.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package query @@ -34,12 +34,12 @@ func NewParameterRewriter() *ParameterRewriter { } } -func (s *ParameterRewriter) Visit(parent, element any) error { +func (s *ParameterRewriter) Visit(stack *model.WalkStack, element model.Expression) error { switch typedElement := element.(type) { case *model.Parameter: var ( nextParameterIndex = s.parameterIndex - nextParameterIndexStr = strconv.Itoa(nextParameterIndex) + nextParameterIndexStr = "p" + strconv.Itoa(nextParameterIndex) ) // Increment the parameter index first diff --git a/packages/go/dawgs/traversal/id.go b/packages/go/dawgs/traversal/id.go index d1be391c84..bc0f4e2512 100644 --- a/packages/go/dawgs/traversal/id.go +++ b/packages/go/dawgs/traversal/id.go @@ -81,7 +81,7 @@ func (s IDTraversal) BreadthFirst(ctx context.Context, plan IDPlan) error { go func(workerID int) { defer workerWG.Done() - if err := s.db.ReadTransaction(traversalCtx, func(tx graph.Transaction) error { + if err := s.db.ReadTransaction(ctx, 
func(tx graph.Transaction) error { for { if nextDescent, ok := channels.Receive(traversalCtx, segmentReaderC); !ok { return nil @@ -101,12 +101,12 @@ func (s IDTraversal) BreadthFirst(ctx context.Context, plan IDPlan) error { errors.Add(fmt.Errorf("%w - Limit: %.2f MB - Memory In-Use: %.2f MB", ops.ErrTraversalMemoryLimit, tx.TraversalMemoryLimit().Mebibytes(), pathTree.SizeOf().Mebibytes())) } - // Mark descent for this segment as complete - descentCount.Add(-1) - if !channels.Submit(traversalCtx, completionC, struct{}{}) { return nil } + + // Mark descent for this segment as complete + descentCount.Add(-1) } }); err != nil && err != graph.ErrContextTimedOut { // A worker encountered a fatal error, kill the traversal context diff --git a/packages/go/dawgs/traversal/traversal.go b/packages/go/dawgs/traversal/traversal.go index f861b982a4..3ec0b72b44 100644 --- a/packages/go/dawgs/traversal/traversal.go +++ b/packages/go/dawgs/traversal/traversal.go @@ -208,6 +208,12 @@ type Plan struct { Driver Driver } +type Service struct { + db graph.Database + workerWG *sync.WaitGroup + numWorkers int +} + type Traversal struct { db graph.Database numWorkers int @@ -221,6 +227,8 @@ func New(db graph.Database, numParallelWorkers int) Traversal { } func (s Traversal) BreadthFirst(ctx context.Context, plan Plan) error { + defer log.Measure(log.LevelDebug, "BreadthFirst - %d workers", s.numWorkers)() + var ( // workerWG keeps count of background workers launched in goroutines workerWG = &sync.WaitGroup{} @@ -258,7 +266,7 @@ func (s Traversal) BreadthFirst(ctx context.Context, plan Plan) error { go func(workerID int) { defer workerWG.Done() - if err := s.db.ReadTransaction(traversalCtx, func(tx graph.Transaction) error { + if err := s.db.ReadTransaction(ctx, func(tx graph.Transaction) error { for { if nextDescent, ok := channels.Receive(traversalCtx, segmentReaderC); !ok { return nil @@ -360,7 +368,7 @@ func shallowFetchRelationships(direction graph.Direction, segment *graph.PathSeg 
return nil, fmt.Errorf("bi-directional or non-directed edges are not supported") } - if err := graphQuery.Execute(func(results graph.Result) error { + if err := graphQuery.Query(func(results graph.Result) error { defer results.Close() var ( diff --git a/packages/go/dawgs/vendormocks/jackc/pgx/v5/mock.go b/packages/go/dawgs/vendormocks/jackc/pgx/v5/mock.go new file mode 100644 index 0000000000..389566b068 --- /dev/null +++ b/packages/go/dawgs/vendormocks/jackc/pgx/v5/mock.go @@ -0,0 +1,227 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/jackc/pgx/v5 (interfaces: Tx) + +// Package pgx is a generated GoMock package. +package pgx + +import ( + context "context" + reflect "reflect" + + pgx "github.com/jackc/pgx/v5" + pgconn "github.com/jackc/pgx/v5/pgconn" + gomock "go.uber.org/mock/gomock" +) + +// MockTx is a mock of Tx interface. +type MockTx struct { + ctrl *gomock.Controller + recorder *MockTxMockRecorder +} + +// MockTxMockRecorder is the mock recorder for MockTx. +type MockTxMockRecorder struct { + mock *MockTx +} + +// NewMockTx creates a new mock instance. +func NewMockTx(ctrl *gomock.Controller) *MockTx { + mock := &MockTx{ctrl: ctrl} + mock.recorder = &MockTxMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. 
+func (m *MockTx) EXPECT() *MockTxMockRecorder { + return m.recorder +} + +// Begin mocks base method. +func (m *MockTx) Begin(arg0 context.Context) (pgx.Tx, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0) + ret0, _ := ret[0].(pgx.Tx) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockTxMockRecorder) Begin(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTx)(nil).Begin), arg0) +} + +// Commit mocks base method. +func (m *MockTx) Commit(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTxMockRecorder) Commit(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTx)(nil).Commit), arg0) +} + +// Conn mocks base method. +func (m *MockTx) Conn() *pgx.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Conn") + ret0, _ := ret[0].(*pgx.Conn) + return ret0 +} + +// Conn indicates an expected call of Conn. +func (mr *MockTxMockRecorder) Conn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Conn", reflect.TypeOf((*MockTx)(nil).Conn)) +} + +// CopyFrom mocks base method. +func (m *MockTx) CopyFrom(arg0 context.Context, arg1 pgx.Identifier, arg2 []string, arg3 pgx.CopyFromSource) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CopyFrom", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CopyFrom indicates an expected call of CopyFrom. 
+func (mr *MockTxMockRecorder) CopyFrom(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CopyFrom", reflect.TypeOf((*MockTx)(nil).CopyFrom), arg0, arg1, arg2, arg3) +} + +// Exec mocks base method. +func (m *MockTx) Exec(arg0 context.Context, arg1 string, arg2 ...interface{}) (pgconn.CommandTag, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(pgconn.CommandTag) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTxMockRecorder) Exec(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTx)(nil).Exec), varargs...) +} + +// LargeObjects mocks base method. +func (m *MockTx) LargeObjects() pgx.LargeObjects { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LargeObjects") + ret0, _ := ret[0].(pgx.LargeObjects) + return ret0 +} + +// LargeObjects indicates an expected call of LargeObjects. +func (mr *MockTxMockRecorder) LargeObjects() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LargeObjects", reflect.TypeOf((*MockTx)(nil).LargeObjects)) +} + +// Prepare mocks base method. +func (m *MockTx) Prepare(arg0 context.Context, arg1, arg2 string) (*pgconn.StatementDescription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0, arg1, arg2) + ret0, _ := ret[0].(*pgconn.StatementDescription) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Prepare indicates an expected call of Prepare. 
+func (mr *MockTxMockRecorder) Prepare(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockTx)(nil).Prepare), arg0, arg1, arg2) +} + +// Query mocks base method. +func (m *MockTx) Query(arg0 context.Context, arg1 string, arg2 ...interface{}) (pgx.Rows, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(pgx.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockTxMockRecorder) Query(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTx)(nil).Query), varargs...) +} + +// QueryRow mocks base method. +func (m *MockTx) QueryRow(arg0 context.Context, arg1 string, arg2 ...interface{}) pgx.Row { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(pgx.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockTxMockRecorder) QueryRow(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTx)(nil).QueryRow), varargs...) +} + +// Rollback mocks base method. +func (m *MockTx) Rollback(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. 
+func (mr *MockTxMockRecorder) Rollback(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTx)(nil).Rollback), arg0) +} + +// SendBatch mocks base method. +func (m *MockTx) SendBatch(arg0 context.Context, arg1 *pgx.Batch) pgx.BatchResults { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendBatch", arg0, arg1) + ret0, _ := ret[0].(pgx.BatchResults) + return ret0 +} + +// SendBatch indicates an expected call of SendBatch. +func (mr *MockTxMockRecorder) SendBatch(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendBatch", reflect.TypeOf((*MockTx)(nil).SendBatch), arg0, arg1) +} diff --git a/packages/go/dawgs/vendormocks/vendor.go b/packages/go/dawgs/vendormocks/vendor.go index 6ee9e51228..523f8c6e0f 100644 --- a/packages/go/dawgs/vendormocks/vendor.go +++ b/packages/go/dawgs/vendormocks/vendor.go @@ -1,19 +1,20 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
-// +// // SPDX-License-Identifier: Apache-2.0 package vendormocks //go:generate go run go.uber.org/mock/mockgen -copyright_file=../../../../LICENSE.header -destination=./neo4j/neo4j-go-driver/v5/neo4j/mock.go -package=neo4j github.com/neo4j/neo4j-go-driver/v5/neo4j Result,Transaction,Session +//go:generate go run go.uber.org/mock/mockgen -copyright_file=../../../../LICENSE.header -destination=./jackc/pgx/v5/mock.go -package=pgx github.com/jackc/pgx/v5 Tx diff --git a/packages/go/ein/ad.go b/packages/go/ein/ad.go index db6f927166..b0acf15437 100644 --- a/packages/go/ein/ad.go +++ b/packages/go/ein/ad.go @@ -256,7 +256,7 @@ func ParseDomainTrusts(domain Domain) ParsedDomainTrustData { return parsedData } -// ParseComputerMiscData parses AllowedToDelegate, AllowedToAct, HasSIDHistory,DumpSMSAPassword and Sessions +// ParseComputerMiscData parses AllowedToDelegate, AllowedToAct, HasSIDHistory,DumpSMSAPassword,DCFor and Sessions func ParseComputerMiscData(computer Computer) []IngestibleRelationship { relationships := make([]IngestibleRelationship, 0) for _, target := range computer.AllowedToDelegate { @@ -342,6 +342,17 @@ func ParseComputerMiscData(computer Computer) []IngestibleRelationship { } } + if computer.IsDC && computer.DomainSID != "" { + relationships = append(relationships, IngestibleRelationship{ + Source: computer.ObjectIdentifier, + SourceType: ad.Computer, + TargetType: ad.Domain, + Target: computer.DomainSID, + RelProps: map[string]any{"isacl": false}, + RelType: ad.DCFor, + }) + } + return relationships } diff --git a/packages/go/ein/go.mod b/packages/go/ein/go.mod index 919c1df0bd..f6afc00542 100644 --- a/packages/go/ein/go.mod +++ b/packages/go/ein/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/ein -go 1.20 +go 1.21 require github.com/bloodhoundad/azurehound/v2 v2.0.1 diff --git a/packages/go/ein/incoming_models.go b/packages/go/ein/incoming_models.go index 1b55af1cef..c74ff29c89 100644 --- a/packages/go/ein/incoming_models.go +++ 
b/packages/go/ein/incoming_models.go @@ -259,6 +259,8 @@ type Computer struct { DCRegistryData DCRegistryData Status ComputerStatus HasSIDHistory []TypedPrincipal + IsDC bool + DomainSID string } type OU struct { diff --git a/packages/go/errors/go.mod b/packages/go/errors/go.mod index ccc82d9984..72067dee7f 100644 --- a/packages/go/errors/go.mod +++ b/packages/go/errors/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/errors -go 1.20 +go 1.21 require github.com/stretchr/testify v1.8.4 diff --git a/packages/go/graphschema/ad/ad.go b/packages/go/graphschema/ad/ad.go index 031a147863..ccfe7bffc8 100644 --- a/packages/go/graphschema/ad/ad.go +++ b/packages/go/graphschema/ad/ad.go @@ -644,7 +644,7 @@ func ACLRelationships() []graph.Kind { return []graph.Kind{AllExtendedRights, ForceChangePassword, AddMember, AddAllowedToAct, GenericAll, WriteDACL, WriteOwner, GenericWrite, ReadLAPSPassword, ReadGMSAPassword, Owns, AddSelf, WriteSPN, AddKeyCredentialLink, GetChanges, GetChangesAll, GetChangesInFilteredSet, WriteAccountRestrictions, SyncLAPSPassword, DCSync, ManageCertificates, ManageCA, Enroll, WritePKIEnrollmentFlag, WritePKINameFlag} } func PathfindingRelationships() []graph.Kind { - return []graph.Kind{Owns, GenericAll, GenericWrite, WriteOwner, WriteDACL, MemberOf, ForceChangePassword, AllExtendedRights, AddMember, HasSession, Contains, GPLink, AllowedToDelegate, TrustedBy, AllowedToAct, AdminTo, CanPSRemote, CanRDP, ExecuteDCOM, HasSIDHistory, AddSelf, DCSync, ReadLAPSPassword, ReadGMSAPassword, DumpSMSAPassword, SQLAdmin, AddAllowedToAct, WriteSPN, AddKeyCredentialLink, SyncLAPSPassword, WriteAccountRestrictions, GoldenCert, ADCSESC1, ADCSESC3, ADCSESC4, ADCSESC5, ADCSESC6, ADCSESC7} + return []graph.Kind{Owns, GenericAll, GenericWrite, WriteOwner, WriteDACL, MemberOf, ForceChangePassword, AllExtendedRights, AddMember, HasSession, Contains, GPLink, AllowedToDelegate, TrustedBy, AllowedToAct, AdminTo, CanPSRemote, CanRDP, ExecuteDCOM, HasSIDHistory, 
AddSelf, DCSync, ReadLAPSPassword, ReadGMSAPassword, DumpSMSAPassword, SQLAdmin, AddAllowedToAct, WriteSPN, AddKeyCredentialLink, SyncLAPSPassword, WriteAccountRestrictions, GoldenCert, ADCSESC1, ADCSESC3, ADCSESC4, ADCSESC5, ADCSESC6, ADCSESC7, DCFor} } func IsACLKind(s graph.Kind) bool { for _, acl := range ACLRelationships() { diff --git a/packages/go/graphschema/go.mod b/packages/go/graphschema/go.mod index a99664ae0a..f778060dc3 100644 --- a/packages/go/graphschema/go.mod +++ b/packages/go/graphschema/go.mod @@ -16,4 +16,4 @@ module github.com/specterops/bloodhound/graphschema -go 1.20 +go 1.21 diff --git a/packages/go/graphschema/graph.go b/packages/go/graphschema/graph.go index c59259c1e0..34a7f9b11b 100644 --- a/packages/go/graphschema/graph.go +++ b/packages/go/graphschema/graph.go @@ -17,7 +17,7 @@ // Code generated by Cuelang code gen. DO NOT EDIT! // Cuelang source: github.com/specterops/bloodhound/-/tree/main/packages/cue/schemas/ -package schema +package graphschema import graph "github.com/specterops/bloodhound/dawgs/graph" diff --git a/packages/go/graphschema/schema.go b/packages/go/graphschema/schema.go new file mode 100644 index 0000000000..b14e8be6db --- /dev/null +++ b/packages/go/graphschema/schema.go @@ -0,0 +1,159 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package graphschema + +import ( + "github.com/specterops/bloodhound/dawgs/graph" + "github.com/specterops/bloodhound/graphschema/ad" + "github.com/specterops/bloodhound/graphschema/azure" + "github.com/specterops/bloodhound/graphschema/common" +) + +const ( + ActiveDirectoryGraphPrefix = "ad" + AzureGraphPrefix = "az" +) + +func ActiveDirectoryGraphName(suffix string) string { + return ActiveDirectoryGraphPrefix + "_" + suffix +} + +func AzureGraphName(suffix string) string { + return AzureGraphPrefix + "_" + suffix +} + +func CombinedGraphSchema(name string) graph.Graph { + return graph.Graph{ + Name: name, + Nodes: append(common.NodeKinds(), append(azure.NodeKinds(), ad.NodeKinds()...)...), + Edges: append(common.Relationships(), append(azure.Relationships(), ad.Relationships()...)...), + NodeConstraints: []graph.Constraint{{ + Field: common.ObjectID.String(), + Type: graph.BTreeIndex, + }}, + NodeIndexes: []graph.Index{ + { + Field: common.Name.String(), + Type: graph.TextSearchIndex, + }, + { + Field: common.SystemTags.String(), + Type: graph.TextSearchIndex, + }, + { + Field: common.UserTags.String(), + Type: graph.TextSearchIndex, + }, + { + Field: ad.DomainSID.String(), + Type: graph.BTreeIndex, + }, + { + Field: azure.TenantID.String(), + Type: graph.BTreeIndex, + }, + }, + } +} + +func AzureGraphSchema(name string) graph.Graph { + return graph.Graph{ + Name: name, + Nodes: azure.NodeKinds(), + Edges: azure.Relationships(), + NodeConstraints: []graph.Constraint{{ + Field: common.ObjectID.String(), + Type: graph.TextSearchIndex, + }}, + NodeIndexes: []graph.Index{ + { + Field: common.Name.String(), + Type: graph.TextSearchIndex, + }, + { + Field: common.SystemTags.String(), + Type: graph.TextSearchIndex, + }, + { + Field: common.UserTags.String(), + Type: graph.TextSearchIndex, + }, + { + Field: azure.TenantID.String(), + Type: graph.BTreeIndex, + }, + }, + } +} + +func ActiveDirectoryGraphSchema(name string) 
graph.Graph { + return graph.Graph{ + Name: name, + Nodes: ad.NodeKinds(), + Edges: ad.Relationships(), + NodeConstraints: []graph.Constraint{{ + Field: common.ObjectID.String(), + Type: graph.TextSearchIndex, + }}, + NodeIndexes: []graph.Index{ + { + Field: common.Name.String(), + Type: graph.TextSearchIndex, + }, + { + Field: ad.CertThumbprint.String(), + Type: graph.BTreeIndex, + }, + { + Field: common.SystemTags.String(), + Type: graph.TextSearchIndex, + }, + { + Field: common.UserTags.String(), + Type: graph.TextSearchIndex, + }, + { + Field: ad.DistinguishedName.String(), + Type: graph.BTreeIndex, + }, + { + Field: ad.DomainFQDN.String(), + Type: graph.BTreeIndex, + }, + { + Field: ad.DomainSID.String(), + Type: graph.BTreeIndex, + }, + }, + } +} + +func DefaultGraph() graph.Graph { + return CombinedGraphSchema("default") +} + +func DefaultGraphSchema() graph.Schema { + defaultGraph := DefaultGraph() + + return graph.Schema{ + Graphs: []graph.Graph{ + defaultGraph, + }, + + DefaultGraph: defaultGraph, + } +} diff --git a/packages/go/headers/go.mod b/packages/go/headers/go.mod index ef1026257a..e2170a619b 100644 --- a/packages/go/headers/go.mod +++ b/packages/go/headers/go.mod @@ -16,4 +16,4 @@ module github.com/specterops/bloodhound/headers -go 1.20 +go 1.21 diff --git a/packages/go/headers/headers.go b/packages/go/headers/headers.go index 9115939f9e..54ccca4930 100644 --- a/packages/go/headers/headers.go +++ b/packages/go/headers/headers.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 // Code generated by generate.go; DO NOT EDIT. diff --git a/packages/go/lab/fixture.go b/packages/go/lab/fixture.go index d8ac44e14f..81e5e28dfa 100644 --- a/packages/go/lab/fixture.go +++ b/packages/go/lab/fixture.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. // You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package lab @@ -95,7 +95,7 @@ func setDependency(consumer depender, provider depender) error { } } -func SetDependency[T, U any](consumer *Fixture[T], provider *Fixture[U]) error { +func SetDependency(consumer depender, provider depender) error { return setDependency(consumer, provider) } diff --git a/packages/go/lab/go.mod b/packages/go/lab/go.mod index 24627677b7..535ae69e5d 100644 --- a/packages/go/lab/go.mod +++ b/packages/go/lab/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/lab -go 1.20 +go 1.21 require github.com/stretchr/testify v1.8.4 diff --git a/packages/go/lab/logging.go b/packages/go/lab/logging.go index d58c21d773..75e845518c 100644 --- a/packages/go/lab/logging.go +++ b/packages/go/lab/logging.go @@ -1,17 +1,17 @@ // Copyright 2023 Specter Ops, Inc. -// +// // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at -// +// // http://www.apache.org/licenses/LICENSE-2.0 -// +// // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -// +// // SPDX-License-Identifier: Apache-2.0 package lab diff --git a/packages/go/log/go.mod b/packages/go/log/go.mod index fabbbc1fd1..b4c1adb3f6 100644 --- a/packages/go/log/go.mod +++ b/packages/go/log/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/log -go 1.20 +go 1.21 require ( github.com/rs/zerolog v1.29.1 diff --git a/packages/go/log/log.go b/packages/go/log/log.go index 963b284d00..e2072d1954 100644 --- a/packages/go/log/log.go +++ b/packages/go/log/log.go @@ -221,14 +221,21 @@ func Measure(level Level, format string, args ...any) func() { then := time.Now() return func() { - WithLevel(level).Duration(FieldElapsed, time.Since(then)).Msgf(format, args...) + if elapsed := time.Since(then); elapsed >= measureThreshold { + WithLevel(level).Duration(FieldElapsed, elapsed).Msgf(format, args...) 
+ } } } var ( logMeasurePairCounter = atomic.Uint64{} + measureThreshold = time.Second ) +func SetMeasureThreshold(newMeasureThreshold time.Duration) { + measureThreshold = newMeasureThreshold +} + func LogAndMeasure(level Level, format string, args ...any) func() { var ( pairID = logMeasurePairCounter.Add(1) @@ -236,9 +243,12 @@ func LogAndMeasure(level Level, format string, args ...any) func() { then = time.Now() ) - WithLevel(level).Uint64(FieldMeasurementID, pairID).Msg(message) + // Only output the message header on debug + WithLevel(LevelDebug).Uint64(FieldMeasurementID, pairID).Msg(message) return func() { - WithLevel(level).Duration(FieldElapsed, time.Since(then)).Uint64(FieldMeasurementID, pairID).Msg(message) + if elapsed := time.Since(then); elapsed >= measureThreshold { + WithLevel(level).Duration(FieldElapsed, elapsed).Uint64(FieldMeasurementID, pairID).Msg(message) + } } } diff --git a/packages/go/mediatypes/go.mod b/packages/go/mediatypes/go.mod index 344a16a4d9..53d344ea99 100644 --- a/packages/go/mediatypes/go.mod +++ b/packages/go/mediatypes/go.mod @@ -16,4 +16,4 @@ module github.com/specterops/bloodhound/mediatypes -go 1.20 +go 1.21 diff --git a/packages/go/params/go.mod b/packages/go/params/go.mod index d5d79781fb..b5e54806df 100644 --- a/packages/go/params/go.mod +++ b/packages/go/params/go.mod @@ -16,6 +16,6 @@ module github.com/specterops/bloodhound/params -go 1.20 +go 1.21 require github.com/gorilla/mux v1.8.0 diff --git a/packages/go/schemagen/go.mod b/packages/go/schemagen/go.mod index a45af43004..0186dc02d6 100644 --- a/packages/go/schemagen/go.mod +++ b/packages/go/schemagen/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/schemagen -go 1.20 +go 1.21 require ( cuelang.org/go v0.5.0 diff --git a/packages/go/schemagen/main.go b/packages/go/schemagen/main.go index 4290779e45..b70ad50771 100644 --- a/packages/go/schemagen/main.go +++ b/packages/go/schemagen/main.go @@ -34,7 +34,7 @@ type Schema struct { } func 
GenerateGolang(projectRoot string, rootSchema Schema) error { - if err := generator.GenerateGolangSchemaTypes("schema", filepath.Join(projectRoot, "packages/go/graphschema")); err != nil { + if err := generator.GenerateGolangSchemaTypes("graphschema", filepath.Join(projectRoot, "packages/go/graphschema")); err != nil { return err } diff --git a/packages/go/slices/go.mod b/packages/go/slices/go.mod index b353e8c81d..4571e4be58 100644 --- a/packages/go/slices/go.mod +++ b/packages/go/slices/go.mod @@ -16,7 +16,7 @@ module github.com/specterops/bloodhound/slices -go 1.20 +go 1.21 require github.com/stretchr/testify v1.8.4 diff --git a/packages/go/stbernard/go.mod b/packages/go/stbernard/go.mod index 94651e09c5..4d9829afea 100644 --- a/packages/go/stbernard/go.mod +++ b/packages/go/stbernard/go.mod @@ -16,6 +16,6 @@ module github.com/specterops/bloodhound/packages/go/stbernard -go 1.20 +go 1.21 require golang.org/x/mod v0.11.0 diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/ADCSESC6a.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/ADCSESC6a.tsx new file mode 100644 index 0000000000..35d0f7205c --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/ADCSESC6a.tsx @@ -0,0 +1,31 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +import General from './General'; +import WindowsAbuse from './WindowsAbuse'; +import LinuxAbuse from './LinuxAbuse'; +import Opsec from './Opsec'; +import References from './References'; + +const ADCSESC6a = { + general: General, + windowsAbuse: WindowsAbuse, + linuxAbuse: LinuxAbuse, + opsec: Opsec, + references: References, +}; + +export default ADCSESC6a; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/General.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/General.tsx new file mode 100644 index 0000000000..1fa2df4a11 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/General.tsx @@ -0,0 +1,44 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { groupSpecialFormat } from '../utils'; +import { EdgeInfoProps } from '../index'; +import { Typography } from '@mui/material'; + +const General: FC = ({ sourceName, sourceType, targetName }) => { + return ( + <> + + {groupSpecialFormat(sourceType, sourceName)} the privileges to perform the ADCS ESC6 Scenario A attack + against the target domain. + + + The principal has permission to enroll on one or more certificate templates allowing for authentication. + They also have enrollment permission for an enterprise CA with the necessary templates published. 
This + enterprise CA is trusted for NT authentication in the forest, and chains up to a root CA for the forest. + The enterprise CA is configured with the EDITF_ATTRIBUTESUBJECTALTNAME2 flag allowing enrollees to + specify a Subject Alternate Name (SAN) identifying another principal during certificate enrollment of + any published certificate template. This setup allow an attacker principal to obtain a malicious + certificate as another principal. There is an affected Domain Controller configured to allow weak + certificate binding enforcement, which enables the attacker principal to authenticate with the malicious + certificate and thereby impersonating any AD forest user or computer without their credentials. + + + ); +}; + +export default General; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/LinuxAbuse.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/LinuxAbuse.tsx new file mode 100644 index 0000000000..874627307e --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/LinuxAbuse.tsx @@ -0,0 +1,42 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const LinuxAbuse: FC = () => { + return ( + <> + An attacker may perform this attack in the following steps: + + Step 1: Use Certipy to request enrollment in the affected template, specifying the affected + enterprise CA and target principal to impersonate: + + + { + 'certipy req -u john@corp.local -p Passw0rd -ca corp-DC-CA -target ca.corp.local -template ESC6 -upn administrator@corp.local' + } + + + Step 2: Request a ticket granting ticket (TGT) from the domain, specifying the certificate + created in Step 1 and the IP of a domain controller: + + {'certipy auth -pfx administrator.pfx -dc-ip 172.16.126.128'} + + ); +}; + +export default LinuxAbuse; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/Opsec.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/Opsec.tsx new file mode 100644 index 0000000000..0200106041 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/Opsec.tsx @@ -0,0 +1,31 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const Opsec: FC = () => { + return ( + + When the affected certificate authority issues the certificate to the attacker, it will retain a local copy + of that certificate in its issued certificates store. Defenders may analyze those issued certificates to + identify illegitimately issued certificates and identify the principal that requested the certificate, as + well as the target identity the attacker is attempting to impersonate. + + ); +}; + +export default Opsec; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/References.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/References.tsx new file mode 100644 index 0000000000..ff137d162c --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/References.tsx @@ -0,0 +1,47 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Link, Box } from '@mui/material'; + +const References: FC = () => { + return ( + + + Certified Pre-Owned + +
+ + Certipy 4.0 + +
+ + Domain Escalation Edit Attributes + +
+ ); +}; + +export default References; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/WindowsAbuse.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/WindowsAbuse.tsx new file mode 100644 index 0000000000..12e36054d6 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/ADCSESC6a/WindowsAbuse.tsx @@ -0,0 +1,52 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const WindowsAbuse: FC = () => { + return ( + <> + An attacker may perform this attack in the following steps: + + Step 1: Use Certify to request enrollment in the affected template, specifying the affected + enterprise CA and target principal to impersonate: + + + { + '.\\Certify.exe request /ca:rootdomaindc.forestroot.com\\forestroot-RootDomainDC-CA /template:ESC6 /altname:forestroot\\ForestRootDA' + } + + + Step 2: Convert the emitted certificate to PFX format: + + {'certutil.exe -MergePFX .cert.pem .cert.pfx'} + + Step 3: Use Rubeus to request a ticket granting ticket (TGT) from the domain, specifying the + target identity to impersonate and the PFX-formatted certificate created in Step 2: + + + {'.\\Rubeus.exe asktgt /certificate:cert.pfx /user:”forestroot\\forestrootda” /password:asdf /ptt'} + + + Step 4: Optionally verify the TGT by listing it with the klist command: + + 
{'klist'} + + ); +}; + +export default WindowsAbuse; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/CanAbuseUPNCertMapping.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/CanAbuseUPNCertMapping.tsx new file mode 100644 index 0000000000..354fc850a9 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/CanAbuseUPNCertMapping.tsx @@ -0,0 +1,31 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import General from './General'; +import WindowsAbuse from './WindowsAbuse'; +import LinuxAbuse from './LinuxAbuse'; +import Opsec from './Opsec'; +import References from './References'; + +const CanAbuseUPNCertMapping = { + general: General, + windowsAbuse: WindowsAbuse, + linuxAbuse: LinuxAbuse, + opsec: Opsec, + references: References, +}; + +export default CanAbuseUPNCertMapping; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/General.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/General.tsx new file mode 100644 index 0000000000..26a2231d33 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/General.tsx @@ -0,0 +1,31 @@ +// Copyright 2024 Specter Ops, Inc. 
+// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { EdgeInfoProps } from '../index'; +import { Typography } from '@mui/material'; + +const General: FC = ({ sourceName, sourceType, targetName }) => { + return ( + + This edge is created when BloodHound identifies a domain controller with particular certificate mapping + methods configured in the registry. This edge alone is not enough to perform an abuse, but may be part of + several other node and edge configurations that create the conditions for abusable ADCS edges. + + ); +}; + +export default General; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/LinuxAbuse.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/LinuxAbuse.tsx new file mode 100644 index 0000000000..14266ff24a --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/LinuxAbuse.tsx @@ -0,0 +1,29 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const LinuxAbuse: FC = () => { + return ( + + An attacker may perform an ADCS ESC6 or ESC10 attack that relies on this relationship. This relationship + alone is not enough to escalate rights or impersonate other principals. + + ); +}; + +export default LinuxAbuse; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/Opsec.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/Opsec.tsx new file mode 100644 index 0000000000..0200106041 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/Opsec.tsx @@ -0,0 +1,31 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const Opsec: FC = () => { + return ( + + When the affected certificate authority issues the certificate to the attacker, it will retain a local copy + of that certificate in its issued certificates store. Defenders may analyze those issued certificates to + identify illegitimately issued certificates and identify the principal that requested the certificate, as + well as the target identity the attacker is attempting to impersonate. + + ); +}; + +export default Opsec; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/References.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/References.tsx new file mode 100644 index 0000000000..860e6a56d7 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/References.tsx @@ -0,0 +1,40 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Link, Box } from '@mui/material'; + +const References: FC = () => { + return ( + + + Certified Pre-Owned + +
+ + Certipy 4.0 + +
+ ); +}; + +export default References; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/WindowsAbuse.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/WindowsAbuse.tsx new file mode 100644 index 0000000000..8a2adb5f30 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseUPNCertMapping/WindowsAbuse.tsx @@ -0,0 +1,29 @@ +// Copyright 2024 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const WindowsAbuse: FC = () => { + return ( + + An attacker may perform an ADCS ESC6 or ESC10 attack that relies on this relationship. This relationship + alone is not enough to escalate rights or impersonate other principals. + + ); +}; + +export default WindowsAbuse; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/CanAbuseWeakCertBinding.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/CanAbuseWeakCertBinding.tsx new file mode 100644 index 0000000000..f83a3d4c75 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/CanAbuseWeakCertBinding.tsx @@ -0,0 +1,31 @@ +// Copyright 2023 Specter Ops, Inc. 
+// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import General from './General'; +import WindowsAbuse from './WindowsAbuse'; +import LinuxAbuse from './LinuxAbuse'; +import Opsec from './Opsec'; +import References from './References'; + +const ADCSESC1 = { + general: General, + windowsAbuse: WindowsAbuse, + linuxAbuse: LinuxAbuse, + opsec: Opsec, + references: References, +}; + +export default ADCSESC1; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/General.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/General.tsx new file mode 100644 index 0000000000..4e49c91cb5 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/General.tsx @@ -0,0 +1,31 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { EdgeInfoProps } from '../index'; +import { Typography } from '@mui/material'; + +const General: FC = ({ sourceName, sourceType, targetName }) => { + return ( + + This edge is created when BloodHound identifies a domain controller with a particular certificate binding + enforcement configuration in the registry. This edge alone is not enough to perform an abuse, but may be + part of several other node and edge configurations that create the conditions for abusable ADCS edges. + + ); +}; + +export default General; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/LinuxAbuse.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/LinuxAbuse.tsx new file mode 100644 index 0000000000..80c80eb16c --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/LinuxAbuse.tsx @@ -0,0 +1,29 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const LinuxAbuse: FC = () => { + return ( + + An attacker may perform an ADCS ESC6 or ESC9 attack that relies on this relationship. This relationship + alone is not enough to escalate rights or impersonate other principals. 
+ + ); +}; + +export default LinuxAbuse; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/Opsec.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/Opsec.tsx new file mode 100644 index 0000000000..7a43445c57 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/Opsec.tsx @@ -0,0 +1,31 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const Opsec: FC = () => { + return ( + + When the affected certificate authority issues the certificate to the attacker, it will retain a local copy + of that certificate in its issued certificates store. Defenders may analyze those issued certificates to + identify illegitimately issued certificates and identify the principal that requested the certificate, as + well as the target identity the attacker is attempting to impersonate. 
+ + ); +}; + +export default Opsec; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/References.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/References.tsx new file mode 100644 index 0000000000..d254bf2f6c --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/References.tsx @@ -0,0 +1,40 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Link, Box } from '@mui/material'; + +const References: FC = () => { + return ( + + + Certified Pre-Owned + +
+ + Certipy 4.0 + +
+ ); +}; + +export default References; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/WindowsAbuse.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/WindowsAbuse.tsx new file mode 100644 index 0000000000..f07b003205 --- /dev/null +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/CanAbuseWeakCertBinding/WindowsAbuse.tsx @@ -0,0 +1,29 @@ +// Copyright 2023 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +import { FC } from 'react'; +import { Typography } from '@mui/material'; + +const WindowsAbuse: FC = () => { + return ( + + An attacker may perform an ADCS ESC6 or ESC9 attack that relies on this relationship. This relationship + alone is not enough to escalate rights or impersonate other principals. + + ); +}; + +export default WindowsAbuse; diff --git a/packages/javascript/bh-shared-ui/src/components/HelpTexts/index.tsx b/packages/javascript/bh-shared-ui/src/components/HelpTexts/index.tsx index c898b7d475..f94169e206 100644 --- a/packages/javascript/bh-shared-ui/src/components/HelpTexts/index.tsx +++ b/packages/javascript/bh-shared-ui/src/components/HelpTexts/index.tsx @@ -1,4 +1,4 @@ -// Copyright 2023 Specter Ops, Inc. +// Copyright 2024 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. 
@@ -64,6 +64,8 @@ import AdminTo from './AdminTo/AdminTo'; import AllExtendedRights from './AllExtendedRights/AllExtendedRights'; import AllowedToAct from './AllowedToAct/AllowedToAct'; import AllowedToDelegate from './AllowedToDelegate/AllowedToDelegate'; +import CanAbuseUPNCertMapping from './CanAbuseUPNCertMapping/CanAbuseUPNCertMapping'; +import CanAbuseWeakCertBinding from './CanAbuseWeakCertBinding/CanAbuseWeakCertBinding'; import CanPSRemote from './CanPSRemote/CanPSRemote'; import CanRDP from './CanRDP/CanRDP'; import Contains from './Contains/Contains'; @@ -107,6 +109,7 @@ import WritePKIEnrollmentFlag from './WritePKIEnrollmentFlag/WritePKIEnrollmentF import WritePKINameFlag from './WritePKINameFlag/WritePKINameFlag'; import WriteSPN from './WriteSPN/WriteSPN'; import ADCSESC1 from './ADCSESC1/ADCSESC1'; +import ADCSESC6a from './ADCSESC6a/ADCSESC6a'; import ADCSESC6b from './ADCSESC6b/ADCSESC6b'; export type EdgeInfoProps = { @@ -145,6 +148,8 @@ const EdgeInfoComponents = { ReadGMSAPassword: ReadGMSAPassword, HasSIDHistory: HasSIDHistory, TrustedBy: TrustedBy, + CanAbuseUPNCertMapping: CanAbuseUPNCertMapping, + CanAbuseWeakCertBinding: CanAbuseWeakCertBinding, CanPSRemote: CanPSRemote, AZAddMembers: AZAddMembers, AZAddSecret: AZAddSecret, @@ -208,6 +213,7 @@ const EdgeInfoComponents = { GoldenCert: GoldenCert, ADCSESC1: ADCSESC1, ADCSESC3: ADCSESC3, + ADCSESC6a: ADCSESC6a, ADCSESC6b: ADCSESC6b, ManageCA: ManageCA, ManageCertificates: ManageCertificates, diff --git a/packages/javascript/bh-shared-ui/src/graphSchema.ts b/packages/javascript/bh-shared-ui/src/graphSchema.ts index 04c748bcc2..05ddc0b561 100644 --- a/packages/javascript/bh-shared-ui/src/graphSchema.ts +++ b/packages/javascript/bh-shared-ui/src/graphSchema.ts @@ -520,6 +520,7 @@ export function ActiveDirectoryPathfindingEdges(): ActiveDirectoryRelationshipKi ActiveDirectoryRelationshipKind.ADCSESC5, ActiveDirectoryRelationshipKind.ADCSESC6, ActiveDirectoryRelationshipKind.ADCSESC7, + 
ActiveDirectoryRelationshipKind.DCFor, ]; } export enum AzureNodeKind { diff --git a/packages/python/beagle/beagle/plan/golang.py b/packages/python/beagle/beagle/plan/golang.py index 7916aeced4..98788e4018 100644 --- a/packages/python/beagle/beagle/plan/golang.py +++ b/packages/python/beagle/beagle/plan/golang.py @@ -1,17 +1,17 @@ # Copyright 2023 Specter Ops, Inc. -# +# # Licensed under the Apache License, Version 2.0 # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# +# # SPDX-License-Identifier: Apache-2.0 import os @@ -123,6 +123,12 @@ def list(cls, path: str, recursive: bool = True) -> "List[GoModule]": module_list_output = run(cmd=["go", "list", "-json", path_target], cwd=path, capture_stderr=True) for module_json in json_multi_loads(module_list_output): + go_files = module_json.get("GoFiles") + + if go_files is None: + print(f"Module {module_json['Name']} does not repot having any go files. Skipping.") + continue + if isinstance(module_json, list): raise Exception("Unexpected type during JSON deserialization: expected a dict but got a list.") diff --git a/packages/python/beagle/beagle/test.py b/packages/python/beagle/beagle/test.py index a744da6025..864ae73e0d 100644 --- a/packages/python/beagle/beagle/test.py +++ b/packages/python/beagle/beagle/test.py @@ -1,17 +1,17 @@ # Copyright 2023 Specter Ops, Inc. -# +# # Licensed under the Apache License, Version 2.0 # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# +# # SPDX-License-Identifier: Apache-2.0 import os @@ -69,7 +69,6 @@ def _start_integration_test_services(project_ctx: ProjectContext) -> None: # Wait 15 seconds for the services to online time.sleep(15) - def _stop_integration_test_services(project_ctx: ProjectContext) -> None: project_ctx.info("Stopping integration test services") diff --git a/tools/docker-compose/api.Dockerfile b/tools/docker-compose/api.Dockerfile index 14d951c34c..7919196755 100644 --- a/tools/docker-compose/api.Dockerfile +++ b/tools/docker-compose/api.Dockerfile @@ -62,7 +62,7 @@ WORKDIR /tmp/azurehound/artifacts RUN 7z a -tzip -mx9 azurehound-$AZUREHOUND_VERSION.zip azurehound-* RUN sha256sum azurehound-$AZUREHOUND_VERSION.zip > azurehound-$AZUREHOUND_VERSION.zip.sha256 -FROM docker.io/library/golang:1.20 +FROM docker.io/library/golang:1.21 ARG SHARPHOUND_VERSION ARG AZUREHOUND_VERSION ENV GOFLAGS="-buildvcs=false"