diff --git a/pkg/construct2/graph.go b/pkg/construct2/graph.go index 58e6bc1f9..8352084e8 100644 --- a/pkg/construct2/graph.go +++ b/pkg/construct2/graph.go @@ -33,7 +33,7 @@ func NewGraph(options ...func(*graph.Traits)) Graph { } func NewAcyclicGraph(options ...func(*graph.Traits)) Graph { - return NewGraph(graph.PreventCycles()) + return NewGraphWithOptions(append(options, graph.Directed(), graph.PreventCycles())...) } func ResourceHasher(r *Resource) ResourceId { diff --git a/pkg/construct2/graph_io.go b/pkg/construct2/graph_io.go index 530563450..169cf8c70 100644 --- a/pkg/construct2/graph_io.go +++ b/pkg/construct2/graph_io.go @@ -166,9 +166,7 @@ func (e SimpleEdge) Less(other SimpleEdge) bool { return ResourceIdLess(e.Target, other.Target) } -func (e *SimpleEdge) UnmarshalText(data []byte) error { - s := string(data) - +func (e *SimpleEdge) Parse(s string) error { source, target, found := strings.Cut(s, " -> ") if !found { target, source, found = strings.Cut(s, " <- ") @@ -176,10 +174,21 @@ func (e *SimpleEdge) UnmarshalText(data []byte) error { return errors.New("invalid edge format, expected either `source -> target` or `target <- source`") } } + return errors.Join( + e.Source.Parse(source), + e.Target.Parse(target), + ) +} + +func (e *SimpleEdge) Validate() error { + return errors.Join(e.Source.Validate(), e.Target.Validate()) +} - srcErr := e.Source.UnmarshalText([]byte(source)) - tgtErr := e.Target.UnmarshalText([]byte(target)) - return errors.Join(srcErr, tgtErr) +func (e *SimpleEdge) UnmarshalText(data []byte) error { + if err := e.Parse(string(data)); err != nil { + return err + } + return e.Validate() } func (e SimpleEdge) ToEdge() Edge { diff --git a/pkg/construct2/graphtest/graph.go b/pkg/construct2/graphtest/graph.go index dad9ba6d2..0b00d50fe 100644 --- a/pkg/construct2/graphtest/graph.go +++ b/pkg/construct2/graphtest/graph.go @@ -57,25 +57,84 @@ func AssertGraphContains(t *testing.T, expect, actual construct2.Graph) { } } -func stringToGraphElement(e string) (any, error) { +func StringToGraphElement(e string) (any, error) { var id construct2.ResourceId - idErr := id.UnmarshalText([]byte(e)) - if idErr == nil { + idErr := id.Parse(e) + if id.Validate() == nil { return id, nil } - var edge construct2.SimpleEdge - edgeErr := edge.UnmarshalText([]byte(e)) - if edgeErr == nil { - return edge, nil - } var path construct2.Path - pathErr := path.UnmarshalText([]byte(e)) - if pathErr == nil { + pathErr := path.Parse(e) + if len(path) > 0 { return path, nil } - return nil, errors.Join(idErr, edgeErr, pathErr) + return nil, errors.Join(idErr, pathErr) +} + +// AddElement is a utility function for adding an element to a graph. See [MakeGraph] for more information on supported +// element types. Returns whether adding the element failed. +func AddElement(t *testing.T, g construct2.Graph, e any) (failed bool) { + must := func(err error) { + if err != nil { + t.Fatal(err) + } + } + if estr, ok := e.(string); ok { + var err error + e, err = StringToGraphElement(estr) + if err != nil { + t.Errorf("invalid element %q (type %[1]T) Parse errors: %v", e, err) + return true + } + } + + addIfMissing := func(res *construct2.Resource) { + if _, err := g.Vertex(res.ID); errors.Is(err, graph.ErrVertexNotFound) { + must(g.AddVertex(res)) + } else if err != nil { + t.Fatal(fmt.Errorf("could check vertex %s: %w", res.ID, err)) + } + } + + switch e := e.(type) { + case construct2.ResourceId: + addIfMissing(&construct2.Resource{ID: e}) + + case construct2.Resource: + must(g.AddVertex(&e)) + + case *construct2.Resource: + must(g.AddVertex(e)) + + case construct2.Edge: + addIfMissing(&construct2.Resource{ID: e.Source}) + addIfMissing(&construct2.Resource{ID: e.Target}) + must(g.AddEdge(e.Source, e.Target)) + + case construct2.ResourceEdge: + addIfMissing(e.Source) + addIfMissing(e.Target) + must(g.AddEdge(e.Source.ID, e.Target.ID)) + + case construct2.SimpleEdge: + addIfMissing(&construct2.Resource{ID: e.Source}) + addIfMissing(&construct2.Resource{ID: e.Target}) + must(g.AddEdge(e.Source, e.Target)) + + case construct2.Path: + for i, id := range e { + addIfMissing(&construct2.Resource{ID: id}) + if i > 0 { + must(g.AddEdge(e[i-1], id)) + } + } + default: + t.Errorf("invalid element of type %T", e) + return true + } + return false } // MakeGraph is a utility function for creating a graph from a list of elements which can be of types: @@ -92,62 +151,11 @@ func stringToGraphElement(e string) (any, error) { // return MakeGraph(t, NewGraph(), elements...) // } func MakeGraph(t *testing.T, g construct2.Graph, elements ...any) construct2.Graph { - must := func(err error) { - if err != nil { - t.Fatal(err) - } - } - addIfMissing := func(res *construct2.Resource) { - if _, err := g.Vertex(res.ID); errors.Is(err, graph.ErrVertexNotFound) { - must(g.AddVertex(res)) - } else if err != nil { - t.Fatal(fmt.Errorf("could check vertex %s: %w", res.ID, err)) - } - } failed := false for i, e := range elements { - if estr, ok := e.(string); ok { - var err error - e, err = stringToGraphElement(estr) - if err != nil { - t.Errorf("invalid element[%d] %q (type %[2]T) Parse errors: %v", i, e, err) - failed = true - } - } - switch e := e.(type) { - case construct2.ResourceId: - addIfMissing(&construct2.Resource{ID: e}) - - case construct2.Resource: - must(g.AddVertex(&e)) - - case *construct2.Resource: - must(g.AddVertex(e)) - - case construct2.Edge: - addIfMissing(&construct2.Resource{ID: e.Source}) - addIfMissing(&construct2.Resource{ID: e.Target}) - must(g.AddEdge(e.Source, e.Target)) - - case construct2.ResourceEdge: - addIfMissing(e.Source) - addIfMissing(e.Target) - must(g.AddEdge(e.Source.ID, e.Target.ID)) - - case construct2.SimpleEdge: - addIfMissing(&construct2.Resource{ID: e.Source}) - addIfMissing(&construct2.Resource{ID: e.Target}) - must(g.AddEdge(e.Source, e.Target)) - - case construct2.Path: - for i, id := range e { - addIfMissing(&construct2.Resource{ID: id}) - if i > 0 { - must(g.AddEdge(e[i-1], id)) - } - } - default: - t.Errorf("invalid element[%d] of type %T", i, e) + elemFailed := AddElement(t, g, e) + if elemFailed { + t.Errorf("failed to add element[%d] (%v) to graph", i, e) failed = true } } diff --git a/pkg/construct2/graphtest/ids.go b/pkg/construct2/graphtest/ids.go index 9ce3316ca..08e4eed2f 100644 --- a/pkg/construct2/graphtest/ids.go +++ b/pkg/construct2/graphtest/ids.go @@ -7,7 +7,7 @@ import ( ) func ParseId(t *testing.T, str string) (id construct.ResourceId) { - err := id.UnmarshalText([]byte(str)) + err := id.Parse(str) if err != nil { t.Fatalf("failed to parse resource id %q: %v", str, err) } @@ -16,7 +16,7 @@ func ParseId(t *testing.T, str string) (id construct.ResourceId) { func ParseEdge(t *testing.T, str string) construct.Edge { var io construct.SimpleEdge - err := io.UnmarshalText([]byte(str)) + err := io.Parse(str) if err != nil { t.Fatalf("failed to parse edge %q: %v", str, err) } @@ -28,7 +28,7 @@ func ParseEdge(t *testing.T, str string) construct.Edge { func ParseRef(t *testing.T, str string) construct.PropertyRef { var ref construct.PropertyRef - err := ref.UnmarshalText([]byte(str)) + err := ref.Parse(str) if err != nil { t.Fatalf("failed to parse property ref %q: %v", str, err) } @@ -37,7 +37,7 @@ func ParseRef(t *testing.T, str string) construct.PropertyRef { func ParsePath(t *testing.T, str string) construct.Path { var path construct.Path - err := path.UnmarshalText([]byte(str)) + err := path.Parse(str) if err != nil { t.Fatalf("failed to parse path %q: %v", str, err) } diff --git a/pkg/construct2/paths.go b/pkg/construct2/paths.go index dc2a38328..a8b69476a 100644 --- a/pkg/construct2/paths.go +++ b/pkg/construct2/paths.go @@ -42,18 +42,37 @@ func (p Path) MarshalText() ([]byte, error) { return []byte(p.String()), nil } -func (p *Path) UnmarshalText(text []byte) error { - parts := strings.Split(string(text), " -> ") +func (p *Path) Parse(s string) error { + parts := strings.Split(s, " -> ") *p = make(Path, len(parts)) + var errs error for i, part := range parts { var id ResourceId - err := id.UnmarshalText([]byte(part)) + err := id.Parse(part) if err != nil { - return err + errs = errors.Join(errs, fmt.Errorf("could not parse path[%d]: %w", i, err)) } (*p)[i] = id } - return nil + return errs +} + +func (p *Path) Validate() error { + var errs error + for i, id := range *p { + err := id.Validate() + if err != nil { + errs = errors.Join(errs, fmt.Errorf("path[%d] invalid: %w", i, err)) + } + } + return errs +} + +func (p *Path) UnmarshalText(text []byte) error { + if err := p.Parse(string(text)); err != nil { + return err + } + return p.Validate() } func (d *Dependencies) Add(p Path) { diff --git a/pkg/construct2/property_ref.go b/pkg/construct2/property_ref.go index 0737804fd..c6ff70860 100644 --- a/pkg/construct2/property_ref.go +++ b/pkg/construct2/property_ref.go @@ -1,8 +1,8 @@ package construct2 import ( - "bytes" "fmt" + "strings" ) type PropertyRef struct { @@ -18,15 +18,22 @@ func (v PropertyRef) MarshalText() ([]byte, error) { return []byte(v.String()), nil } -func (v *PropertyRef) UnmarshalText(b []byte) error { - parts := bytes.SplitN(b, []byte("#"), 2) - if len(parts) != 2 { - return fmt.Errorf("invalid PropertyRef format: %s", string(b)) +func (v *PropertyRef) Parse(s string) error { + res, prop, ok := strings.Cut(s, "#") + if !ok { + return fmt.Errorf("invalid PropertyRef format: %s", s) } - err := v.Resource.UnmarshalText(parts[0]) - if err != nil { + v.Property = prop + return v.Resource.Parse(res) +} + +func (v *PropertyRef) Validate() error { + return v.Resource.Validate() +} + +func (v *PropertyRef) UnmarshalText(b []byte) error { + if err := v.Parse(string(b)); err != nil { return err } - v.Property = string(parts[1]) - return nil + return v.Validate() } diff --git a/pkg/construct2/resource_id.go b/pkg/construct2/resource_id.go index 4f51f1ce2..0c83241c4 100644 --- a/pkg/construct2/resource_id.go +++ b/pkg/construct2/resource_id.go @@ -153,8 +153,8 @@ var ( resourceNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_./\-:\[\]]*$`) ) -func (id *ResourceId) UnmarshalText(data []byte) error { - parts := strings.SplitN(string(data), ":", 4) +func (id *ResourceId) Parse(s string) error { + parts := strings.SplitN(s, ":", 4) switch len(parts) { case 4: id.Name = parts[3] @@ -174,6 +174,10 @@ func (id *ResourceId) UnmarshalText(data []byte) error { return fmt.Errorf("must have trailing ':' for provider-only ID") } } + return nil +} + +func (id *ResourceId) Validate() error { if id.IsZero() { return nil } @@ -191,11 +195,19 @@ func (id *ResourceId) UnmarshalText(data []byte) error { err = errors.Join(err, fmt.Errorf("invalid name '%s' (must match %s)", id.Name, resourceNamePattern)) } if err != nil { - return fmt.Errorf("invalid resource id '%s': %w", string(data), err) + return fmt.Errorf("invalid resource id '%s': %w", id.String(), err) } return nil } +func (id *ResourceId) UnmarshalText(data []byte) error { + err := id.Parse(string(data)) + if err != nil { + return err + } + return id.Validate() +} + func (id ResourceId) MarshalTOML() ([]byte, error) { return id.MarshalText() } diff --git a/pkg/engine2/path_selection/candidate_weight.go b/pkg/engine2/path_selection/candidate_weight.go index 89b5dc013..ff1ad29ab 100644 --- a/pkg/engine2/path_selection/candidate_weight.go +++ b/pkg/engine2/path_selection/candidate_weight.go @@ -12,18 +12,23 @@ import ( ) // determineCandidateWeight determines the weight of a candidate resource based on its relationship to the src and target resources -// and if it is already in the result graph +// and if it is already in the result graph. // // The weight is determined by the following: // 1. If the candidate is downstream of the src or upstream of the target, add 10 to the weight // 2. If the candidate is in the result graph, add 9 to the weight // 3. if the candidate is existing determine how close it is to the src and target resources for additional weighting +// +// 'undirected' is from the 'ctx' raw view, but given as an argument here to avoid having to recompute it. +// 'desc' return is purely for debugging purposes, describing the weight calculation. func determineCandidateWeight( ctx solution_context.SolutionContext, src, target construct.ResourceId, id construct.ResourceId, resultGraph construct.Graph, + undirected construct.Graph, ) (weight int, errs error) { + // note(gg) perf: these Downstream/Upstream functions don't need the full list and don't need to run twice downstreams, err := solution_context.Downstream(ctx, src, knowledgebase.ResourceDirectLayer) errs = errors.Join(errs, err) if collectionutil.Contains(downstreams, id) { @@ -53,11 +58,6 @@ func determineCandidateWeight( weight += 9 } - undirected, err := BuildUndirectedGraph(ctx) - if err != nil { - errs = errors.Join(errs, err) - return - } pather, err := construct.ShortestPaths(undirected, id, construct.DontSkipEdges) if err != nil { errs = errors.Join(errs, err) @@ -76,7 +76,6 @@ func determineCandidateWeight( } else { availableWeight -= 1 } - } shortestPath, err = pather.ShortestPath(target) @@ -99,25 +98,26 @@ func determineCandidateWeight( return } -func BuildUndirectedGraph(ctx solution_context.SolutionContext) (construct.Graph, error) { +func BuildUndirectedGraph(g construct.Graph, kb knowledgebase.TemplateKB) (construct.Graph, error) { undirected := graph.NewWithStore( construct.ResourceHasher, graph_addons.NewMemoryStore[construct.ResourceId, *construct.Resource](), + graph.Weighted(), ) - err := undirected.AddVerticesFrom(ctx.RawView()) + err := undirected.AddVerticesFrom(g) if err != nil { return nil, err } - edges, err := ctx.RawView().Edges() + edges, err := g.Edges() if err != nil { return nil, err } for _, e := range edges { weight := 1 // increase weights for edges that are connected to a functional resource - if knowledgebase.GetFunctionality(ctx.KnowledgeBase(), e.Source) != knowledgebase.Unknown { + if knowledgebase.GetFunctionality(kb, e.Source) != knowledgebase.Unknown { weight = 1000 - } else if knowledgebase.GetFunctionality(ctx.KnowledgeBase(), e.Target) != knowledgebase.Unknown { + } else if knowledgebase.GetFunctionality(kb, e.Target) != knowledgebase.Unknown { weight = 1000 } err := undirected.AddEdge(e.Source, e.Target, graph.EdgeWeight(weight)) diff --git a/pkg/engine2/path_selection/candidate_weight_test.go b/pkg/engine2/path_selection/candidate_weight_test.go new file mode 100644 index 000000000..eb87353ad --- /dev/null +++ b/pkg/engine2/path_selection/candidate_weight_test.go @@ -0,0 +1,159 @@ +package path_selection + +import ( + "testing" + + construct "github.com/klothoplatform/klotho/pkg/construct2" + "github.com/klothoplatform/klotho/pkg/construct2/graphtest" + "github.com/klothoplatform/klotho/pkg/engine2/enginetesting" + knowledgebase "github.com/klothoplatform/klotho/pkg/knowledge_base2" + "github.com/klothoplatform/klotho/pkg/knowledge_base2/kbtesting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func Test_determineCandidateWeight(t *testing.T) { + // NOTE(gg): this test is a little brittle, since the weights don't really have any meaning other than + // their relative values. They're made up to get the desired results in path selection. + tests := []struct { + name string + graph []any + resultGraph []any + src, target string + id string + wantWeight int + wantErr bool + }{ + { + name: "no relation", + graph: []any{"p:compute:a -> p:glue:b -> p:glue:c -> p:compute:d"}, + src: "p:compute:a", + target: "p:compute:d", + id: "p:compute:e", + wantWeight: 2, + }, + { + name: "no relation, in result graph", + graph: []any{"p:compute:a -> p:glue:b -> p:glue:c -> p:compute:d"}, + resultGraph: []any{"p:compute:e"}, + src: "p:compute:a", + target: "p:compute:d", + id: "p:compute:e", + wantWeight: 11, + }, + { + name: "downstream direct", + graph: []any{"p:compute:a -> p:glue:b -> p:glue:c -> p:compute:d"}, + src: "p:compute:a", + target: "p:compute:d", + id: "p:glue:b", + wantWeight: 21, + }, + { + name: "downstream indirect", + graph: []any{"p:compute:a -> p:glue:b -> p:glue:c -> p:glue:d -> p:compute:e"}, + src: "p:compute:a", + target: "p:compute:e", + id: "p:glue:c", + wantWeight: 15, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := enginetesting.NewTestSolution() + ctx.KB. + On("GetResourceTemplate", mock.MatchedBy(construct.ResourceId{Type: "compute"}.Matches)). + Return(&knowledgebase.ResourceTemplate{ + Classification: knowledgebase.Classification{ + Is: []string{"compute"}, + }, + }, nil) + ctx.KB. + On("GetResourceTemplate", mock.MatchedBy(construct.ResourceId{Type: "glue"}.Matches)). + Return(&knowledgebase.ResourceTemplate{}, nil) + ctx.KB. + On("GetEdgeTemplate", mock.Anything, mock.Anything). + Return(&knowledgebase.EdgeTemplate{}) + + ctx.LoadState(t, tt.graph...) + + resultGraph := graphtest.MakeGraph(t, construct.NewGraph(), tt.resultGraph...) + + undirected, err := BuildUndirectedGraph(ctx.DataflowGraph(), ctx.KnowledgeBase()) + require.NoError(t, err) + + src := graphtest.ParseId(t, tt.src) + target := graphtest.ParseId(t, tt.target) + id := graphtest.ParseId(t, tt.id) + + gotWeight, err := determineCandidateWeight(ctx, src, target, id, resultGraph, undirected) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantWeight, gotWeight) + }) + } +} + +func TestBuildUndirectedGraph(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + kb := &kbtesting.MockKB{} + kb.Mock. + On("GetResourceTemplate", mock.MatchedBy(construct.ResourceId{Type: "compute"}.Matches)). + Return(&knowledgebase.ResourceTemplate{ + Classification: knowledgebase.Classification{ + Is: []string{"compute"}, + }, + }, nil) + kb.Mock. + On("GetResourceTemplate", mock.MatchedBy(construct.ResourceId{Type: "glue"}.Matches)). + Return(&knowledgebase.ResourceTemplate{}, nil) + + graph := graphtest.MakeGraph(t, construct.NewGraph(), + "p:compute:a -> p:glue:b -> p:glue:c -> p:compute:d", + "p:compute:e", + ) + + got, err := BuildUndirectedGraph(graph, kb) + require.NoError(err) + + assert.False(got.Traits().IsDirected, "graph should be undirected") + assert.True(got.Traits().IsWeighted, "graph should be weighted") + + for _, f := range []func(construct.Graph) (int, error){construct.Graph.Order, construct.Graph.Size} { + want, err := f(graph) + require.NoError(err) + got, err := f(got) + require.NoError(err) + assert.Equal(want, got) + } + + getNodes := func(g construct.Graph) []construct.ResourceId { + adj, err := g.AdjacencyMap() + require.NoError(err) + nodes := make([]construct.ResourceId, 0, len(adj)) + for n := range adj { + nodes = append(nodes, n) + } + return nodes + } + assert.ElementsMatch(getNodes(graph), getNodes(got)) + + wantWeights := map[string]int{ + "p:compute:a -> p:glue:b": 1000, // compute -> unknown + "p:glue:b -> p:glue:c": 1, // unknown -> unknown + "p:glue:c -> p:compute:d": 1000, // unknown -> compute + } + for e, w := range wantWeights { + wantEdge := graphtest.ParseEdge(t, e) + gotEdge, err := got.Edge(wantEdge.Source, wantEdge.Target) + + if assert.NoError(err) { + assert.Equal(w, gotEdge.Properties.Weight) + } + } +} diff --git a/pkg/engine2/path_selection/path_expansion.go b/pkg/engine2/path_selection/path_expansion.go index 3f45d07c7..df2e6e7eb 100644 --- a/pkg/engine2/path_selection/path_expansion.go +++ b/pkg/engine2/path_selection/path_expansion.go @@ -313,6 +313,11 @@ func expandPath( return errs } + undirected, err := BuildUndirectedGraph(ctx.RawView(), ctx.KnowledgeBase()) + if err != nil { + return err + } + addCandidates := func(id construct.ResourceId, resource *construct.Resource, nerr error) error { matchIdx := matchesNonBoundary(id, nonBoundaryResources) if matchIdx < 0 { @@ -334,7 +339,7 @@ func expandPath( if _, ok := candidates[matchIdx][id]; !ok { candidates[matchIdx][id] = 0 } - weight, err := determineCandidateWeight(ctx, input.Dep.Source.ID, input.Dep.Target.ID, id, resultGraph) + weight, err := determineCandidateWeight(ctx, input.Dep.Source.ID, input.Dep.Target.ID, id, resultGraph, undirected) if err != nil { return errors.Join(nerr, err) } @@ -352,7 +357,7 @@ func expandPath( } // We need to add candidates which exist in our current result graph so we can reuse them. We do this in case // we have already performed expansions to ensure the namespaces are connected, etc - err := construct.WalkGraph(resultGraph, func(id construct.ResourceId, resource *construct.Resource, nerr error) error { + err = construct.WalkGraph(resultGraph, func(id construct.ResourceId, resource *construct.Resource, nerr error) error { return addCandidates(id, resource, nerr) }) if err != nil { diff --git a/pkg/engine2/path_selection/path_selection.go b/pkg/engine2/path_selection/path_selection.go index 9bee6855a..4bc4eb34d 100644 --- a/pkg/engine2/path_selection/path_selection.go +++ b/pkg/engine2/path_selection/path_selection.go @@ -3,7 +3,6 @@ package path_selection import ( "errors" "fmt" - "math/rand" "github.com/dominikbraun/graph" "github.com/klothoplatform/klotho/pkg/collectionutil" @@ -90,8 +89,10 @@ func BuildPathSelectionGraph( } var prevRes construct.ResourceId for i, res := range path { - id := res.Id() - id.Name = fmt.Sprintf("%s%s", PHANTOM_PREFIX, generateStringSuffix(5)) + id, err := makePhantom(tempGraph, res.Id()) + if err != nil { + return nil, err + } if i == 0 { id = dep.Source } else if i == len(path)-1 { @@ -150,14 +151,15 @@ func PathSatisfiesClassification( return true } -func generateStringSuffix(n int) string { - var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - b := make([]rune, n) - for i := range b { - b[i] = letterRunes[rand.Intn(len(letterRunes))] +func makePhantom(g construct.Graph, id construct.ResourceId) (construct.ResourceId, error) { + for suffix := 0; suffix < 1000; suffix++ { + candidate := id + candidate.Name = fmt.Sprintf("%s%d", PHANTOM_PREFIX, suffix) + if _, err := g.Vertex(candidate); errors.Is(err, graph.ErrVertexNotFound) { + return candidate, nil + } } - return string(b) - + return id, fmt.Errorf("exhausted suffixes for creating phantom for %s", id) } func calculateEdgeWeight( diff --git a/pkg/engine2/path_selection/path_selection_test.go b/pkg/engine2/path_selection/path_selection_test.go new file mode 100644 index 000000000..37a187602 --- /dev/null +++ b/pkg/engine2/path_selection/path_selection_test.go @@ -0,0 +1,114 @@ +package path_selection + +import ( + "fmt" + "testing" + + "github.com/dominikbraun/graph" + construct "github.com/klothoplatform/klotho/pkg/construct2" + "github.com/klothoplatform/klotho/pkg/construct2/graphtest" + knowledgebase "github.com/klothoplatform/klotho/pkg/knowledge_base2" + "github.com/klothoplatform/klotho/pkg/knowledge_base2/kbtesting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestBuildPathSelectionGraph(t *testing.T) { + addRes := func(kb *kbtesting.MockKB, s string, is ...string) { + r := graphtest.ParseId(t, s) + kb.On("GetResourceTemplate", mock.MatchedBy(r.Matches)). + Return(&knowledgebase.ResourceTemplate{ + Classification: knowledgebase.Classification{Is: is}, + }, nil) + } + addEdge := func(kb *kbtesting.MockKB, s string) { + e := graphtest.ParseEdge(t, s) + kb.On("GetEdgeTemplate", mock.MatchedBy(e.Source.Matches), mock.MatchedBy(e.Target.Matches)). + Return(&knowledgebase.EdgeTemplate{}) + } + type args struct { + dep string + kb func(t *testing.T, kb *kbtesting.MockKB) + classification string + } + tests := []struct { + name string + args args + want []any + wantWeights map[string]int + wantErr bool + }{ + { + name: "no edge", + args: args{ + dep: "p:t:a -> p:t:b", + kb: func(t *testing.T, kb *kbtesting.MockKB) { + addRes(kb, "p:t") + kb.On("AllPaths", mock.Anything, mock.Anything).Return([][]*knowledgebase.ResourceTemplate{}, nil) + }, + classification: "network", + }, + want: []any{"p:t:a", "p:t:b"}, + }, + { + name: "path through classification", + args: args{ + dep: "p:a:a -> p:c:c", + kb: func(t *testing.T, kb *kbtesting.MockKB) { + addRes(kb, "p:a") + addRes(kb, "p:b", "network") + addRes(kb, "p:c") + addEdge(kb, "p:a -> p:b") + addEdge(kb, "p:b -> p:c") + kb.On("AllPaths", mock.Anything, mock.Anything).Return([][]*knowledgebase.ResourceTemplate{ + { + {QualifiedTypeName: "p:a"}, + {QualifiedTypeName: "p:b"}, + {QualifiedTypeName: "p:c"}, + }, + }, nil) + }, + classification: "network", + }, + want: []any{"p:a:a -> p:b:phantom$0 -> p:c:c"}, + wantWeights: map[string]int{ + "p:a:a -> p:b:phantom$0": 109, + "p:b:phantom$0 -> p:c:c": 109, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dep := graphtest.ParseEdge(t, tt.args.dep) + + kb := &kbtesting.MockKB{} + kb.Test(t) + tt.args.kb(t, kb) + kb.On("GetEdgeTemplate", mock.Anything, mock.Anything). + Return((*knowledgebase.EdgeTemplate)(nil)) + + got, err := BuildPathSelectionGraph( + construct.SimpleEdge{Source: dep.Source, Target: dep.Target}, + kb, + tt.args.classification, + ) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + want := graphtest.MakeGraph(t, construct.NewGraph(), tt.want...) + wantS, _ := construct.String(want) + fmt.Println(wantS) + for s, ww := range tt.wantWeights { + e := graphtest.ParseEdge(t, s) + require.NoError(t, want.UpdateEdge(e.Source, e.Target, graph.EdgeWeight(ww))) + } + graphtest.AssertGraphEqual(t, want, got, "") + + assert.True(t, got.Traits().IsWeighted, "not weighted: %+v", got.Traits()) + }) + } +}