From 58d73f6c52504157f1f11393e761f25b73758d09 Mon Sep 17 00:00:00 2001 From: DavidSeptimus-Klotho Date: Fri, 19 Jul 2024 17:08:58 -0600 Subject: [PATCH] Adds support for structured inputs and for-each rules --- .gitignore | 4 + pkg/construct/dot.go | 2 +- pkg/construct/graph_update.go | 4 +- pkg/construct/properties.go | 141 +-- pkg/construct/properties_test.go | 69 +- pkg/engine/dot.go | 2 +- pkg/engine/operational_eval/eval.go | 5 +- .../operational_rule/operational_rule.go | 2 +- pkg/engine/solution.go | 2 +- pkg/infra/iac/prop_refs.go | 2 +- pkg/k2/constructs/binding.go | 91 +- pkg/k2/constructs/construct.go | 186 +--- pkg/k2/constructs/construct_evaluator.go | 994 +++++------------- pkg/k2/constructs/construct_evaluator_test.go | 232 ++-- pkg/k2/constructs/construct_marshaller.go | 32 +- .../constructs/construct_marshaller_test.go | 117 +-- pkg/k2/constructs/construct_test.go | 216 ++-- pkg/k2/constructs/constructs.go | 63 +- pkg/k2/constructs/dynamic_value.go | 322 ++++++ pkg/k2/constructs/import_resources.go | 215 ++++ pkg/k2/constructs/input_resolver.go | 113 -- pkg/k2/constructs/interpolation.go | 413 ++++++++ .../constructs/template/binding_template.go | 71 ++ .../{ => template}/construct_template.go | 229 ++-- .../{ => template}/construct_template_test.go | 89 +- .../template/inputs/properties_template.go | 437 ++++++++ .../inputs/properties_template_test.go | 251 +++++ pkg/k2/constructs/template/properties.go | 68 ++ .../template/properties/any_property.go | 108 ++ .../template/properties/any_property_test.go | 230 ++++ .../template/properties/bool_property.go | 101 ++ .../template/properties/bool_property_test.go | 122 +++ .../template/properties/construct_property.go | 176 ++++ .../template/properties/float_property.go | 122 +++ .../properties/float_property_test.go | 251 +++++ .../template/properties/int_property.go | 126 +++ .../template/properties/int_property_test.go | 243 +++++ .../template/properties/key_value_list.go | 229 ++++ .../properties/key_value_list_test.go | 379 +++++++ .../template/properties/list_property.go | 264 +++++ .../template/properties/list_property_test.go | 283 +++++ .../template/properties/map_property.go | 263 +++++ .../template/properties/map_propery_test.go | 359 +++++++ .../template/properties/path_property.go | 161 +++ .../template/properties/properties.go | 138 +++ .../template/properties/set_property.go | 230 ++++ .../template/properties/set_property_test.go | 328 ++++++ .../properties/shared_property_fields.go | 41 + .../template/properties/string_property.go | 122 +++ .../properties/string_property_test.go | 181 ++++ .../template/property/construct_type.go | 67 ++ .../template/property/construct_type_test.go | 50 + .../template/property/interfaces.go | 71 ++ .../template/property/properties.go | 151 +++ .../template/property/property_details.go | 21 + .../template/property/sanitization.go | 94 ++ pkg/k2/constructs/template/property/util.go | 19 + pkg/k2/constructs/template/resource_ref.go | 83 ++ .../{ => template}/template_loader.go | 35 +- .../{ => template}/templates/aws/api/api.yaml | 0 .../api/bindings/to_klotho.aws.Container.yaml | 72 ++ .../api/bindings/to_klotho.aws.Function.yaml | 3 +- .../templates/aws/bucket/bucket.yaml | 15 +- .../bindings/from_klotho.aws.Api.yaml | 0 .../bindings/to_klotho.aws.Bucket.yaml | 6 +- .../bindings/to_klotho.aws.DynamoDB.yaml | 0 .../bindings/to_klotho.aws.LoadBalancer.yaml | 22 +- .../bindings/to_klotho.aws.Postgres.yaml | 0 .../templates/aws/container/container.yaml | 67 +- .../templates/aws/dynamodb/dynamodb.yaml | 47 +- .../bindings/to_klotho.aws.Bucket.yaml | 4 +- .../bindings/to_klotho.aws.DynamoDB.yaml | 0 .../bindings/to_klotho.aws.Postgres.yaml | 0 .../templates/aws/fastapi/fastapi.yaml | 54 +- .../bindings/to_klotho.aws.Bucket.yaml | 15 + .../bindings/to_klotho.aws.DynamoDB.yaml | 0 .../bindings/to_klotho.aws.Postgres.yaml | 0 .../templates/aws/function/function.yaml | 175 ++- .../templates/aws/network/network.yaml | 0 .../bindings/from_klotho.aws.Container.yaml | 0 .../bindings/from_klotho.aws.FastAPI.yaml | 0 .../templates/aws/postgres/postgres.yaml | 39 +- .../api/bindings/to_klotho.aws.Container.yaml | 23 - .../bindings/to_klotho.aws.Bucket.yaml | 21 - pkg/k2/ir_samples/container.yaml | 4 +- pkg/k2/ir_samples/testenv.yaml | 2 +- pkg/k2/k2_test.go | 14 +- pkg/k2/language_host/language_host.go | 1 - .../python/klothosdk/src/klotho/aws/api.py | 27 +- .../python/klothosdk/src/klotho/construct.py | 6 + .../python/klothosdk/src/klotho/output.py | 5 +- .../python/samples/starter/infra-api.py | 10 +- pkg/k2/model/urn.go | 4 + pkg/k2/orchestration/up_orchestrator.go | 9 +- .../bucket_ro/my-bucket.engine_input.yaml | 5 - .../bucket_ro/my-bucket.resources.yaml | 1 - .../bucket_ro/my-container.engine_input.yaml | 1 - .../bucket_ro/my-container.resources.yaml | 1 - .../dynamo/my-dynamodb.resources.yaml | 4 +- .../function/docker-func.engine_input.yaml | 10 - .../function/docker-func.resources.yaml | 3 +- pkg/k2/testdata/function/infra.py | 3 +- .../function/my-api.engine_input.yaml | 3 +- .../testdata/function/my-api.resources.yaml | 3 +- .../function/my-bucket.engine_input.yaml | 5 - .../function/my-bucket.resources.yaml | 1 - .../function/zip-func.engine_input.yaml | 1 - .../testdata/function/zip-func.resources.yaml | 1 - pkg/k2/testdata/simple_api/infra.py | 3 +- .../simple_api/my-api.engine_input.yaml | 21 +- pkg/k2/testdata/simple_api/my-api.index.ts | 28 +- .../testdata/simple_api/my-api.resources.yaml | 32 +- .../simple_api/my-container.engine_input.yaml | 2 +- pkg/knowledgebase/kb.go | 2 +- pkg/reflectutil/map.go | 25 + pkg/{k2 => }/reflectutil/reflectutil.go | 67 +- pkg/{k2 => }/reflectutil/reflectutil_test.go | 63 +- pkg/templateutils/funcs.go | 26 +- pkg/templateutils/funcs_test.go | 386 +++++++ 119 files changed, 8737 insertions(+), 2025 deletions(-) create mode 100644 pkg/k2/constructs/dynamic_value.go create mode 100644 pkg/k2/constructs/import_resources.go delete mode 100644 pkg/k2/constructs/input_resolver.go create mode 100644 pkg/k2/constructs/interpolation.go create mode 100644 pkg/k2/constructs/template/binding_template.go rename pkg/k2/constructs/{ => template}/construct_template.go (54%) rename pkg/k2/constructs/{ => template}/construct_template_test.go (82%) create mode 100644 pkg/k2/constructs/template/inputs/properties_template.go create mode 100644 pkg/k2/constructs/template/inputs/properties_template_test.go create mode 100644 pkg/k2/constructs/template/properties.go create mode 100644 pkg/k2/constructs/template/properties/any_property.go create mode 100644 pkg/k2/constructs/template/properties/any_property_test.go create mode 100644 pkg/k2/constructs/template/properties/bool_property.go create mode 100644 pkg/k2/constructs/template/properties/bool_property_test.go create mode 100644 pkg/k2/constructs/template/properties/construct_property.go create mode 100644 pkg/k2/constructs/template/properties/float_property.go create mode 100644 pkg/k2/constructs/template/properties/float_property_test.go create mode 100644 pkg/k2/constructs/template/properties/int_property.go create mode 100644 pkg/k2/constructs/template/properties/int_property_test.go create mode 100644 pkg/k2/constructs/template/properties/key_value_list.go create mode 100644 pkg/k2/constructs/template/properties/key_value_list_test.go create mode 100644 pkg/k2/constructs/template/properties/list_property.go create mode 100644 pkg/k2/constructs/template/properties/list_property_test.go create mode 100644 pkg/k2/constructs/template/properties/map_property.go create mode 100644 pkg/k2/constructs/template/properties/map_propery_test.go create mode 100644 pkg/k2/constructs/template/properties/path_property.go create mode 100644 pkg/k2/constructs/template/properties/properties.go create mode 100644 pkg/k2/constructs/template/properties/set_property.go create mode 100644 pkg/k2/constructs/template/properties/set_property_test.go create mode 100644 pkg/k2/constructs/template/properties/shared_property_fields.go create mode 100644 pkg/k2/constructs/template/properties/string_property.go create mode 100644 pkg/k2/constructs/template/properties/string_property_test.go create mode 100644 pkg/k2/constructs/template/property/construct_type.go create mode 100644 pkg/k2/constructs/template/property/construct_type_test.go create mode 100644 pkg/k2/constructs/template/property/interfaces.go create mode 100644 pkg/k2/constructs/template/property/properties.go create mode 100644 pkg/k2/constructs/template/property/property_details.go create mode 100644 pkg/k2/constructs/template/property/sanitization.go create mode 100644 pkg/k2/constructs/template/property/util.go create mode 100644 pkg/k2/constructs/template/resource_ref.go rename pkg/k2/constructs/{ => template}/template_loader.go (74%) rename pkg/k2/constructs/{ => template}/templates/aws/api/api.yaml (100%) create mode 100644 pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Container.yaml rename pkg/k2/constructs/{ => template}/templates/aws/api/bindings/to_klotho.aws.Function.yaml (95%) rename pkg/k2/constructs/{ => template}/templates/aws/bucket/bucket.yaml (79%) rename pkg/k2/constructs/{ => template}/templates/aws/container/bindings/from_klotho.aws.Api.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml (90%) rename pkg/k2/constructs/{ => template}/templates/aws/container/bindings/to_klotho.aws.DynamoDB.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml (84%) rename pkg/k2/constructs/{ => template}/templates/aws/container/bindings/to_klotho.aws.Postgres.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/container/container.yaml (73%) rename pkg/k2/constructs/{ => template}/templates/aws/dynamodb/dynamodb.yaml (77%) rename pkg/k2/constructs/{ => template}/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml (93%) rename pkg/k2/constructs/{ => template}/templates/aws/fastapi/bindings/to_klotho.aws.DynamoDB.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/fastapi/bindings/to_klotho.aws.Postgres.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/fastapi/fastapi.yaml (86%) create mode 100644 pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml rename pkg/k2/constructs/{ => template}/templates/aws/function/bindings/to_klotho.aws.DynamoDB.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/function/bindings/to_klotho.aws.Postgres.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/function/function.yaml (50%) rename pkg/k2/constructs/{ => template}/templates/aws/network/network.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/postgres/bindings/from_klotho.aws.Container.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/postgres/bindings/from_klotho.aws.FastAPI.yaml (100%) rename pkg/k2/constructs/{ => template}/templates/aws/postgres/postgres.yaml (85%) delete mode 100644 pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Container.yaml delete mode 100644 pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml create mode 100644 pkg/reflectutil/map.go rename pkg/{k2 => }/reflectutil/reflectutil.go (74%) rename pkg/{k2 => }/reflectutil/reflectutil_test.go (92%) create mode 100644 pkg/templateutils/funcs_test.go diff --git a/.gitignore b/.gitignore index 18ccfe329..bf3f393c7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,10 @@ keyvalue.js.tmpl #OS X things *.DS_Store* +# Jetbrains +.idea/ +*.iml + node_modules/ /*.yaml /_samples/ diff --git a/pkg/construct/dot.go b/pkg/construct/dot.go index b24cc7286..2b5f98d82 100644 --- a/pkg/construct/dot.go +++ b/pkg/construct/dot.go @@ -22,7 +22,7 @@ func dotAttributes(r *Resource) map[string]string { func dotEdgeAttributes(e ResourceEdge) map[string]string { a := make(map[string]string) _ = e.Source.WalkProperties(func(path PropertyPath, nerr error) error { - v := path.Get() + v, _ := path.Get() if v == e.Target.ID { a["label"] = path.String() return StopWalk diff --git a/pkg/construct/graph_update.go b/pkg/construct/graph_update.go index fce8d8cef..17d9751fc 100644 --- a/pkg/construct/graph_update.go +++ b/pkg/construct/graph_update.go @@ -33,7 +33,7 @@ func ReplaceResource(g Graph, oldId ResourceId, newRes *Resource) error { } updateId := func(path PropertyPathItem) error { - itemVal := path.Get() + itemVal, _ := path.Get() if itemId, ok := itemVal.(ResourceId); ok && itemId == oldId { return path.Set(newRes.ID) @@ -112,7 +112,7 @@ func RemoveResource(g Graph, id ResourceId) error { } removeId := func(path PropertyPathItem) (bool, error) { - itemVal := path.Get() + itemVal, _ := path.Get() itemId, ok := itemVal.(ResourceId) if ok && itemId == id { return true, path.Remove(nil) diff --git a/pkg/construct/properties.go b/pkg/construct/properties.go index f7b26b766..b1ac87243 100644 --- a/pkg/construct/properties.go +++ b/pkg/construct/properties.go @@ -1,12 +1,15 @@ package construct import ( + "errors" "fmt" "reflect" "sort" "strconv" "strings" + "github.com/klothoplatform/klotho/pkg/reflectutil" + "github.com/klothoplatform/klotho/pkg/set" "github.com/klothoplatform/klotho/pkg/yaml_util" ) @@ -25,14 +28,22 @@ func (r *Resource) SetProperty(pathStr string, value any) error { } // GetProperty is a wrapper around [PropertyPath.Get] for convenience. +// It returns ErrPropertyDoesNotExist if the property does not exist. func (r *Resource) GetProperty(pathStr string) (any, error) { path, err := r.PropertyPath(pathStr) if err != nil { return nil, err } - return path.Get(), nil + if value, ok := path.Get(); ok { + return value, nil + } + // Backwards compatibility: if the property does not exist, return nil instead of an error. + return nil, nil } +// ErrPropertyDoesNotExist is returned when a property does not exist. +var ErrPropertyDoesNotExist = errors.New("property does not exist") + // AppendProperty is a wrapper around [PropertyPath.Append] for convenience. func (r *Resource) AppendProperty(pathStr string, value any) error { path, err := r.PropertyPath(pathStr) @@ -51,6 +62,13 @@ func (r *Resource) RemoveProperty(pathStr string, value any) error { return path.Remove(value) } +func (r *Resource) PropertyPath(pathStr string) (PropertyPath, error) { + if r.Properties == nil { + r.Properties = Properties{} + } + return r.Properties.PropertyPath(pathStr) +} + func (p Properties) Equals(other any) (equal bool) { otherProps, ok := other.(Properties) if !ok { @@ -67,8 +85,8 @@ func (p Properties) Equals(other any) (equal bool) { equal = false return StopWalk } - v := path.Get() - otherV := otherPath.Get() + v, _ := path.Get() + otherV, _ := otherPath.Get() if v == nil || otherV == nil { equal = v == otherV @@ -93,9 +111,44 @@ func (p Properties) Equals(other any) (equal bool) { return equal } +func (p Properties) SetProperty(pathStr string, value any) error { + path, err := p.PropertyPath(pathStr) + if err != nil { + return err + } + return path.Set(value) +} + +func (p *Properties) GetProperty(pathStr string) (any, error) { + path, err := p.PropertyPath(pathStr) + if err != nil { + return nil, err + } + if value, ok := path.Get(); ok { + return value, nil + } + return nil, ErrPropertyDoesNotExist +} + +func (p Properties) AppendProperty(pathStr string, value any) error { + path, err := p.PropertyPath(pathStr) + if err != nil { + return err + } + return path.Append(value) +} + +func (p Properties) RemoveProperty(pathStr string, value any) error { + path, err := p.PropertyPath(pathStr) + if err != nil { + return err + } + return path.Remove(value) +} + type ( PropertyPathItem interface { - Get() any + Get() (value any, ok bool) Set(value any) error Remove(value any) error Append(value any) error @@ -126,54 +179,10 @@ type ( } ) -func splitPath(path string) []string { - var parts []string - bracket := 0 - lastPartIdx := 0 - for i := 0; i < len(path); i++ { - switch path[i] { - case '.': - if bracket == 0 { - if i > lastPartIdx { - parts = append(parts, path[lastPartIdx:i]) - } - lastPartIdx = i - } - - case '[': - if bracket == 0 { - if i > lastPartIdx { - parts = append(parts, path[lastPartIdx:i]) - } - lastPartIdx = i - } - bracket++ - - case ']': - bracket-- - if bracket == 0 { - parts = append(parts, path[lastPartIdx:i+1]) - lastPartIdx = i + 1 - } - } - if i == len(path)-1 && lastPartIdx <= i { - parts = append(parts, path[lastPartIdx:]) - } - } - return parts -} - -func (r *Resource) PropertyPath(pathStr string) (PropertyPath, error) { - if r.Properties == nil { - r.Properties = Properties{} - } - return r.Properties.PropertyPath(pathStr) -} - // PropertyPath interprets a string path to index (potentially deeply) into [Resource.Properties] // which can be used to get, set, append, or remove values. func (p Properties) PropertyPath(pathStr string) (PropertyPath, error) { - pathParts := splitPath(pathStr) + pathParts := reflectutil.SplitPath(pathStr) if len(pathParts) == 0 { return nil, fmt.Errorf("empty path") } @@ -512,15 +521,15 @@ func (i *mapValuePathItem) Remove(value any) (err error) { return nil } -func (i *mapValuePathItem) Get() any { +func (i *mapValuePathItem) Get() (any, bool) { if !i.m.IsValid() { - return nil + return nil, false } v := i.m.MapIndex(i.key) if !v.IsValid() { - return nil + return nil, false } - return v.Interface() + return v.Interface(), true } func (i *mapValuePathItem) parent() PropertyPathItem { @@ -531,8 +540,8 @@ func (i *mapValuePathItem) Key() PropertyPathItem { return (*mapKeyPathItem)(i) } -func (i *mapKeyPathItem) Get() any { - return i.key.Interface() +func (i *mapKeyPathItem) Get() (any, bool) { + return i.key.Interface(), true } func (i *mapKeyPathItem) Set(value any) (err error) { @@ -605,8 +614,14 @@ func (i *arrayIndexPathItem) Remove(value any) (err error) { return nil } -func (i *arrayIndexPathItem) Get() any { - return i.a.Index(i.index).Interface() +func (i *arrayIndexPathItem) Get() (any, bool) { + if !i.a.IsValid() || !reflectutil.IsAnyOf(reflectutil.GetConcreteElement(i.a), reflect.Slice, reflect.Array) { + return nil, false + } + if i.a.Len() <= i.index { + return nil, false + } + return i.a.Index(i.index).Interface(), true } func (i *arrayIndexPathItem) parent() PropertyPathItem { @@ -624,14 +639,14 @@ func (i PropertyPath) Append(value any) error { } // Remove removes the value at this path item. If value is nil, it is interpreted -// to remove the item itself. Non-nil value'd remove is only supported on array items, to +// to remove the item itself. Non-nil valued remove is only supported on array items, to // remove a value from the array. func (i PropertyPath) Remove(value any) error { return i[len(i)-1].Remove(value) } // Get returns the value at this path item. -func (i PropertyPath) Get() any { +func (i PropertyPath) Get() (any, bool) { return i[len(i)-1].Get() } @@ -700,7 +715,7 @@ func (r *Resource) WalkProperties(fn WalkPropertiesFunc) error { } // WalkProperties walks the properties of the resource, calling fn for each property. If fn returns -// SkipProperty, the property and its decendants (if a map or array type) is skipped. If fn returns +// SkipProperty, the property and its descendants (if a map or array type) are skipped. If fn returns // StopWalk, the walk is stopped. // NOTE: does not walk over the _keys_ of any maps, only values. func (p Properties) WalkProperties(fn WalkPropertiesFunc) error { @@ -733,7 +748,11 @@ func (p Properties) WalkProperties(fn WalkPropertiesFunc) error { } added := make(set.Set[string]) - v := reflect.ValueOf(current.Get()) + rv, ok := current.Get() + if !ok { + continue + } + v := reflect.ValueOf(rv) switch v.Kind() { case reflect.Map: keys, err := mapKeys(v) diff --git a/pkg/construct/properties_test.go b/pkg/construct/properties_test.go index 64e4ee357..50ee8fbb4 100644 --- a/pkg/construct/properties_test.go +++ b/pkg/construct/properties_test.go @@ -4,6 +4,8 @@ import ( "strings" "testing" + "github.com/klothoplatform/klotho/pkg/reflectutil" + "github.com/klothoplatform/klotho/pkg/set" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -60,7 +62,7 @@ func Test_splitPath(t *testing.T) { t.Run(tt.name, func(t *testing.T) { assert := assert.New(t) - got := splitPath(tt.path) + got := reflectutil.SplitPath(tt.path) assert.Equal(tt.want, got) }) } @@ -149,7 +151,9 @@ func TestResource_PropertyPath(t *testing.T) { if !assert.NoError(err) { return } - assert.Equal(tt.want, path.Get()) + v, ok := path.Get() + assert.True(ok) + assert.Equal(tt.want, v) // Test the last item's itemToPath instead of the path's Parts // because this will test both functions (itemToPath and Parts) @@ -185,49 +189,68 @@ func TestResource_PropertyPath_ops(t *testing.T) { } foo := path("A.foo") - assert.Equal("bar", foo.Get()) + v, ok := foo.Get() + assert.True(ok) + assert.Equal("bar", v) if assert.NoError(foo.Set("baz")) { - assert.Equal("baz", foo.Get()) + v, ok := foo.Get() + assert.True(ok) + assert.Equal("baz", v) } assert.Error(foo.Append("value")) if assert.NoError(foo.Remove(nil)) { assert.Nil(foo.Get()) - m := path("A").Get().(map[string]any) + v, ok := path("A").Get() + assert.True(ok) + m := v.(map[string]any) assert.NotContains(m, "foo") } arr := path("A.array") if assert.NoError(arr.Append("cat")) { - assert.Equal([]any{"fox", "bat", "dog", "cat"}, arr.Get()) + v, ok := arr.Get() + assert.True(ok) + assert.Equal([]any{"fox", "bat", "dog", "cat"}, v) } if assert.NoError(arr.Remove("bat")) { - assert.Equal([]any{"fox", "dog", "cat"}, arr.Get()) + v, ok := arr.Get() + assert.True(ok) + assert.Equal([]any{"fox", "dog", "cat"}, v) } fox := path("A.array[0]") - assert.Equal("fox", fox.Get()) + v, _ = fox.Get() + assert.Equal("fox", v) if assert.NoError(fox.Set("wolf")) { - assert.Equal("wolf", fox.Get()) - assert.Equal([]any{"wolf", "dog", "cat"}, arr.Get()) + v, _ = fox.Get() + assert.Equal("wolf", v) + v, _ = arr.Get() + assert.Equal([]any{"wolf", "dog", "cat"}, v) } if assert.NoError(fox.Remove(nil)) { - assert.Equal([]any{"dog", "cat"}, arr.Get()) - assert.Equal("dog", fox.Get()) // [0] now points to "dog" + v, _ = arr.Get() + assert.Equal([]any{"dog", "cat"}, v) + v, _ = fox.Get() + assert.Equal("dog", v) // [0] now points to "dog" } two := path("B[0][1]") - assert.Equal(2, two.Get()) + v, _ = two.Get() + assert.Equal(2, v) if assert.NoError(two.Remove(nil)) { - assert.Equal([]any{1, 3}, path("B[0]").Get()) + v, _ = path("B[0]").Get() + assert.Equal([]any{1, 3}, v) } c := path("C") if assert.NoError(c.Append(map[string]string{"hello": "world"})) { + v, ok := c.Get() + assert.True(ok) assert.Equal(map[string]any{ "x": "y", "hello": "world", - }, c.Get()) + }, v) } d := path("D") @@ -239,18 +262,24 @@ func TestResource_PropertyPath_ops(t *testing.T) { e := path("E") if assert.NoError(e.Set([]string{"one", "two"})) { assert.NoError(e.Append([]string{"three", "four"})) - assert.Equal([]string{"one", "two", "three", "four"}, e.Get()) + v, ok := e.Get() + assert.True(ok) + assert.Equal([]string{"one", "two", "three", "four"}, v) } tmp := path("temp") if assert.NoError(tmp.Append("test")) { - assert.Equal([]string{"test"}, tmp.Get()) + v, ok := tmp.Get() + assert.True(ok) + assert.Equal([]string{"test"}, v) assert.NoError(tmp.Remove(nil)) } assert.Nil(tmp.Get()) if assert.NoError(tmp.Append(map[string]string{"hello": "world"})) { - assert.Equal(map[string]string{"hello": "world"}, tmp.Get()) + v, ok := tmp.Get() + assert.True(ok) + assert.Equal(map[string]string{"hello": "world"}, v) assert.NoError(tmp.Remove(nil)) } @@ -258,7 +287,9 @@ func TestResource_PropertyPath_ops(t *testing.T) { assert.Nil(nested.Get()) assert.Nil(path("deeply").Get()) if assert.NoError(nested.Set("test")) { - assert.Equal(map[string]interface{}(map[string]interface{}{"nested": map[string]interface{}{"value": "test"}}), path("deeply").Get()) + v, ok := path("deeply").Get() + assert.True(ok) + assert.Equal(map[string]interface{}{"nested": map[string]interface{}{"value": "test"}}, v) } } diff --git a/pkg/engine/dot.go b/pkg/engine/dot.go index 9599ff15c..2ab7d8790 100644 --- a/pkg/engine/dot.go +++ b/pkg/engine/dot.go @@ -32,7 +32,7 @@ func dotAttributes(kb knowledgebase.TemplateKB, r *construct.Resource, props gra func dotEdgeAttributes(kb knowledgebase.TemplateKB, g construct.Graph, e construct.ResourceEdge) map[string]string { a := make(map[string]string) _ = e.Source.WalkProperties(func(path construct.PropertyPath, nerr error) error { - v := path.Get() + v, _ := path.Get() if v == e.Target.ID { a["label"] = path.String() return construct.StopWalk diff --git a/pkg/engine/operational_eval/eval.go b/pkg/engine/operational_eval/eval.go index 90ae96c3b..d6745e87d 100644 --- a/pkg/engine/operational_eval/eval.go +++ b/pkg/engine/operational_eval/eval.go @@ -268,7 +268,10 @@ func (eval *Evaluator) cleanupPropertiesSubVertices(ref construct.PropertyRef, r if err == nil { // if the paths parent still exists then we know we will end up evaluating the vertex and should not remove it parentIndex := len(path) - 2 - if parentIndex < 0 || path[parentIndex].Get() != nil { + if parentIndex < 0 { + continue + } + if parent, ok := path[parentIndex].Get(); ok && parent != nil { continue } } diff --git a/pkg/engine/operational_rule/operational_rule.go b/pkg/engine/operational_rule/operational_rule.go index 66450e0bf..574023cf2 100644 --- a/pkg/engine/operational_rule/operational_rule.go +++ b/pkg/engine/operational_rule/operational_rule.go @@ -132,7 +132,7 @@ func (ctx OperationalRuleContext) CleanProperty(step knowledgebase.OperationalSt if err != nil { return err } - prop := path.Get() + prop, _ := path.Get() if prop == nil { return nil } diff --git a/pkg/engine/solution.go b/pkg/engine/solution.go index 05398ad90..25fdfb20e 100644 --- a/pkg/engine/solution.go +++ b/pkg/engine/solution.go @@ -141,7 +141,7 @@ func (s *engineSolution) LoadGraph(graph construct.Graph) error { // ensure any deployment dependencies due to properties are in place return construct.WalkGraph(s.RawView(), func(id construct.ResourceId, resource *construct.Resource, nerr error) error { return errors.Join(nerr, resource.WalkProperties(func(path construct.PropertyPath, werr error) error { - prop := path.Get() + prop, _ := path.Get() err := solution.AddDeploymentDependenciesFromVal(s, resource, prop) return errors.Join(werr, err) })) diff --git a/pkg/infra/iac/prop_refs.go b/pkg/infra/iac/prop_refs.go index cc878a4fb..a5f943606 100644 --- a/pkg/infra/iac/prop_refs.go +++ b/pkg/infra/iac/prop_refs.go @@ -66,7 +66,7 @@ func (tc *TemplatesCompiler) PropertyRefValue(ref construct.PropertyRef) (any, e return nil, err } if path != nil { - val := path.Get() + val, _ := path.Get() if val == nil { return nil, fmt.Errorf("property ref %s is nil", ref) } diff --git a/pkg/k2/constructs/binding.go b/pkg/k2/constructs/binding.go index e78d656fd..7dbf903d9 100644 --- a/pkg/k2/constructs/binding.go +++ b/pkg/k2/constructs/binding.go @@ -1,9 +1,14 @@ package constructs import ( + "errors" + "fmt" "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" "github.com/klothoplatform/klotho/pkg/k2/model" "go.uber.org/zap" + "io/fs" ) type ( @@ -18,7 +23,7 @@ type ( From *Construct To *Construct Priority int - BindingTemplate BindingTemplate + BindingTemplate template.BindingTemplate Meta map[string]any Inputs construct.Properties Resources map[string]*Resource @@ -29,16 +34,19 @@ type ( } ) -func (b *Binding) GetInput(name string) (val any, ok bool) { - val, ok = b.Inputs[name] - return val, ok +func (b *Binding) GetInputs() construct.Properties { + return b.Inputs } -func (b *Binding) GetTemplateResourcesIterator() Iterator[string, ResourceTemplate] { +func (b *Binding) GetInputValue(name string) (value any, err error) { + return b.Inputs.GetProperty(name) +} + +func (b *Binding) GetTemplateResourcesIterator() template.Iterator[string, template.ResourceTemplate] { return b.BindingTemplate.ResourcesIterator() } -func (b *Binding) GetTemplateEdges() []EdgeTemplate { +func (b *Binding) GetTemplateEdges() []template.EdgeTemplate { return b.BindingTemplate.Edges } @@ -63,11 +71,11 @@ func (b *Binding) GetResources() map[string]*Resource { return b.Resources } -func (b *Binding) GetInputRules() []InputRuleTemplate { +func (b *Binding) GetInputRules() []template.InputRuleTemplate { return b.BindingTemplate.InputRules } -func (b *Binding) GetTemplateOutputs() map[string]OutputTemplate { +func (b *Binding) GetTemplateOutputs() map[string]template.OutputTemplate { return b.BindingTemplate.Outputs } @@ -79,10 +87,6 @@ func (b *Binding) DeclareOutput(key string, declaration OutputDeclaration) { b.OutputDeclarations[key] = declaration } -func (b *Binding) GetTemplateInputs() map[string]InputTemplate { - return b.BindingTemplate.Inputs -} - func (b *Binding) GetURN() model.URN { if b.Owner == nil { return model.URN{} @@ -92,13 +96,13 @@ func (b *Binding) GetURN() model.URN { func (b *Binding) String() string { e := Edge{ - From: ResourceRef{ConstructURN: b.From.URN}, - To: ResourceRef{ConstructURN: b.To.URN}, + From: template.ResourceRef{ConstructURN: b.From.URN}, + To: template.ResourceRef{ConstructURN: b.To.URN}, } return e.String() } -func (b *Binding) GetPropertySource() *PropertySource { +func (b *Binding) GetPropertySource() *template.PropertySource { ps := map[string]any{ "inputs": b.Inputs, "resources": b.Resources, @@ -124,47 +128,50 @@ func (b *Binding) GetPropertySource() *PropertySource { "outputs": b.To.Outputs, } } - return NewPropertySource(ps) + return template.NewPropertySource(ps) } -// newBinding creates a new Binding instance -// owner: the construct that owns the binding -// from: the construct that is the source of the binding -// to: the construct that is the target of the binding -// inputs: the inputs to the binding (default values will be populated for any missing inputs) +func (b *Binding) GetConstruct() *Construct { + return b.Owner +} + +// newBinding initializes a new binding instance using the template associated with the owner construct // returns: the new binding instance or an error if one occurred -func (ce *ConstructEvaluator) newBinding(owner, from, to model.URN) (*Binding, error) { - ownerTemplateId, err := ParseConstructTemplateId(owner.Subtype) +func (ce *ConstructEvaluator) newBinding(owner model.URN, d BindingDeclaration) (*Binding, error) { + ownerTemplateId, err := property.ParseConstructType(owner.Subtype) if err != nil { return nil, err } - fromTemplateId, err := ParseConstructTemplateId(from.Subtype) + fromTemplateId, err := property.ParseConstructType(d.From.Subtype) if err != nil { return nil, err } - toTemplateId, err := ParseConstructTemplateId(to.Subtype) + toTemplateId, err := property.ParseConstructType(d.To.Subtype) if err != nil { return nil, err } oc, _ := ce.Constructs.Get(owner) - fc, _ := ce.Constructs.Get(from) - tc, _ := ce.Constructs.Get(to) + fc, _ := ce.Constructs.Get(d.From) + tc, _ := ce.Constructs.Get(d.To) - bt, err := loadBindingTemplate(ownerTemplateId, fromTemplateId, toTemplateId) - if err != nil { + bt, err := template.LoadBindingTemplate(ownerTemplateId, fromTemplateId, toTemplateId) + var pathError *fs.PathError + if errors.As(err, &pathError) { zap.S().Debugf("template not found for binding %s -> %s -> %s", ownerTemplateId, fromTemplateId, toTemplateId) - bt = BindingTemplate{ + bt = template.BindingTemplate{ From: fromTemplateId, To: toTemplateId, Priority: 0, - Inputs: make(map[string]InputTemplate), - Outputs: make(map[string]OutputTemplate), - Resources: make(map[string]ResourceTemplate), + Inputs: template.NewProperties(nil), + Outputs: make(map[string]template.OutputTemplate), + Resources: make(map[string]template.ResourceTemplate), } + } else if err != nil { + return nil, fmt.Errorf("failed to load binding template %s -> %s -> %s: %w", ownerTemplateId.String(), fromTemplateId.String(), toTemplateId.String(), err) } - return &Binding{ + b := &Binding{ Owner: oc, From: fc, To: tc, @@ -177,5 +184,19 @@ func (ce *ConstructEvaluator) newBinding(owner, from, to model.URN) (*Binding, e OutputDeclarations: make(map[string]OutputDeclaration), Outputs: make(map[string]any), InitialGraph: construct.NewGraph(), - }, nil + } + + inputs, err := ce.convertInputs(d.Inputs) + if err != nil { + return nil, fmt.Errorf("invalid inputs for binding %s -> %s: %w", d.From, d.To, err) + } + err = ce.initializeInputs(b, inputs) + if err != nil { + return nil, fmt.Errorf("input initialization failed for binding %s -> %s: %w", d.From, d.To, err) + } + return b, nil +} + +func (b *Binding) ForEachInput(f func(property.Property) error) error { + return b.BindingTemplate.ForEachInput(b.Inputs, f) } diff --git a/pkg/k2/constructs/construct.go b/pkg/k2/constructs/construct.go index 3a5f47769..6f83049a7 100644 --- a/pkg/k2/constructs/construct.go +++ b/pkg/k2/constructs/construct.go @@ -3,21 +3,21 @@ package constructs import ( "errors" "fmt" - + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + inputs2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" "sort" "github.com/klothoplatform/klotho/pkg/construct" "github.com/klothoplatform/klotho/pkg/engine/solution" "github.com/klothoplatform/klotho/pkg/k2/model" - "go.uber.org/zap" ) type ( Construct struct { URN model.URN - ConstructTemplate ConstructTemplate + ConstructTemplate template.ConstructTemplate Meta map[string]any - Inputs map[string]any + Inputs construct.Properties Resources map[string]*Resource Edges []*Edge OutputDeclarations map[string]OutputDeclaration @@ -33,52 +33,27 @@ type ( } Edge struct { - From ResourceRef - To ResourceRef + From template.ResourceRef + To template.ResourceRef Data construct.EdgeData } - ResourceRef struct { - ConstructURN model.URN - ResourceKey string - Property string - Type ResourceRefType - } - OutputDeclaration struct { Name string Ref construct.PropertyRef Value any } - - ResourceRefType string - InterpolationSourceKey string - InterpolationContext struct { - AllowedKeys []InterpolationSourceKey - Construct *Construct - } ) -func NewInterpolationContext(c *Construct, keys []InterpolationSourceKey) InterpolationContext { - if c == nil { - c = &Construct{} - } - return InterpolationContext{ - AllowedKeys: keys, - Construct: c, - } -} - -func (c *Construct) GetInput(name string) (value any, ok bool) { - value, ok = c.Inputs[name] - return value, ok +func (c *Construct) GetInputValue(name string) (value any, err error) { + return c.Inputs.GetProperty(name) } -func (c *Construct) GetTemplateResourcesIterator() Iterator[string, ResourceTemplate] { +func (c *Construct) GetTemplateResourcesIterator() template.Iterator[string, template.ResourceTemplate] { return c.ConstructTemplate.ResourcesIterator() } -func (c *Construct) GetTemplateEdges() []EdgeTemplate { +func (c *Construct) GetTemplateEdges() []template.EdgeTemplate { return c.ConstructTemplate.Edges } @@ -90,16 +65,16 @@ func (c *Construct) SetEdges(edges []*Edge) { c.Edges = edges } -func (c *Construct) GetInputRules() []InputRuleTemplate { +func (c *Construct) GetInputRules() []template.InputRuleTemplate { return c.ConstructTemplate.InputRules } -func (c *Construct) GetTemplateOutputs() map[string]OutputTemplate { +func (c *Construct) GetTemplateOutputs() map[string]template.OutputTemplate { return c.ConstructTemplate.Outputs } -func (c *Construct) GetPropertySource() *PropertySource { - return NewPropertySource(map[string]any{ +func (c *Construct) GetPropertySource() *template.PropertySource { + return template.NewPropertySource(map[string]any{ "inputs": c.Inputs, "resources": c.Resources, "edges": c.Edges, @@ -128,14 +103,14 @@ func (c *Construct) DeclareOutput(key string, declaration OutputDeclaration) { c.OutputDeclarations[key] = declaration } -func (c *Construct) GetTemplateInputs() map[string]InputTemplate { - return c.ConstructTemplate.Inputs -} - func (c *Construct) GetURN() model.URN { return c.URN } +func (c *Construct) GetInputs() construct.Properties { + return c.Inputs +} + func (e *Edge) PrettyPrint() string { return e.From.String() + " -> " + e.To.String() } @@ -144,124 +119,75 @@ func (e *Edge) String() string { return e.PrettyPrint() + " :: " + fmt.Sprintf("%v", e.Data) } -func (r *ResourceRef) String() string { - if r.Type == ResourceRefTypeIaC { - return fmt.Sprintf("%s#%s", r.ResourceKey, r.Property) +// OrderedBindings returns the bindings sorted by priority (lowest to highest). +// If two bindings have the same priority, their declaration order is preserved. +func (c *Construct) OrderedBindings() []*Binding { + if len(c.Bindings) == 0 { + return nil } - return r.ResourceKey -} -const ( - // ResourceRefTypeTemplate is a reference to a resource template and will be fully resolved prior to constraint generation - // e.g., ${resources:resourceName.property} or ${resources:resourceName} - ResourceRefTypeTemplate ResourceRefType = "template" - // ResourceRefTypeIaC is a reference to an infrastructure as code resource that will be resolved by the engine - // e.g., ${resources:resourceName#property} - ResourceRefTypeIaC ResourceRefType = "iac" - // ResourceRefTypeInterpolated is an initial interpolation reference to a resource. - // An interpolated value will be evaluated during initial processing and will be converted to one of the other types. - ResourceRefTypeInterpolated ResourceRefType = "interpolated" -) + sorted := append([]*Binding{}, c.Bindings...) -const ( - // InputsInterpolation is an interpolation source used to interpolate values from the construct's inputs - InputsInterpolation InterpolationSourceKey = "inputs" - // ResourcesInterpolation is an interpolation source used to interpolate values from the construct's resources - ResourcesInterpolation InterpolationSourceKey = "resources" - // EdgesInterpolation is an interpolation source used to interpolate values from the construct's edges - EdgesInterpolation InterpolationSourceKey = "edges" - // MetaInterpolation is an interpolation source used to interpolate values from the construct's metadata - // (i.e., non-properties fields) - MetaInterpolation InterpolationSourceKey = "meta" - // BindingInterpolation is an interpolation source used to interpolate values - // from a binding's from/to constructs using the "from" and "to" interpolation prefixes respectively. - FromInterpolation InterpolationSourceKey = "from" - ToInterpolation InterpolationSourceKey = "to" -) + sort.SliceStable(sorted, func(i, j int) bool { + if c.Bindings[i].Priority == c.Bindings[j].Priority { + return i < j + } + return c.Bindings[i].Priority < c.Bindings[j].Priority + }) + return sorted +} -var ( - ResourceInterpolationContext = []InterpolationSourceKey{InputsInterpolation, ResourcesInterpolation, ResourcesInterpolation} - EdgeInterpolationContext = []InterpolationSourceKey{InputsInterpolation, ResourcesInterpolation, EdgesInterpolation} - OutputInterpolationContext = []InterpolationSourceKey{InputsInterpolation, ResourcesInterpolation, EdgesInterpolation, MetaInterpolation} - InputRuleInterpolationContext = []InterpolationSourceKey{InputsInterpolation, ResourcesInterpolation, EdgesInterpolation, MetaInterpolation} - BindingInterpolationContext = []InterpolationSourceKey{InputsInterpolation, ResourcesInterpolation, EdgesInterpolation, MetaInterpolation, FromInterpolation, ToInterpolation} -) +func (c *Construct) GetConstruct() *Construct { + return c +} + +func (c *Construct) ForEachInput(f func(input inputs2.Property) error) error { + return c.ConstructTemplate.ForEachInput(c.Inputs, f) +} -// NewConstruct creates a new Construct instance from the given URN and inputs. +// newConstruct creates a new Construct instance from the given URN and inputs. // The URN must be a construct URN. // Any inputs that are not provided will be populated with default values from the construct template. -func NewConstruct(constructUrn model.URN, inputs map[string]any) (*Construct, error) { - if _, ok := inputs["Name"]; ok { +func (ce *ConstructEvaluator) newConstruct(constructUrn model.URN, i construct.Properties) (*Construct, error) { + if _, ok := i["Name"]; ok { return nil, errors.New("'Name' is a reserved input key") } if !constructUrn.IsResource() || constructUrn.Type != "construct" { return nil, errors.New("invalid construct URN") } - // Add the construct name to the inputs - inputs["Name"] = constructUrn.ResourceID - - var templateId ConstructTemplateId + /// Load the construct template + var templateId inputs2.ConstructType err := templateId.FromURN(constructUrn) if err != nil { return nil, err } - ct, err := loadConstructTemplate(templateId) + ct, err := template.LoadConstructTemplate(templateId) if err != nil { return nil, err } - populateDefaultInputValues(inputs, ct.Inputs) - - return &Construct{ + c := &Construct{ URN: constructUrn, ConstructTemplate: ct, Meta: make(map[string]any), - Inputs: inputs, + Inputs: make(construct.Properties), Resources: make(map[string]*Resource), Edges: []*Edge{}, OutputDeclarations: make(map[string]OutputDeclaration), Outputs: make(map[string]any), InitialGraph: construct.NewGraph(), - }, nil -} - -func populateDefaultInputValues(inputs map[string]any, templates map[string]InputTemplate) { - for key, t := range templates { - if _, hasVal := inputs[key]; !hasVal && t.Default != nil { - defaultValue := t.Default - if t.Type == "path" { - pStr, ok := defaultValue.(string) - if !ok { - continue - } - var err error - defaultValue, err = handlePathInput(pStr) - if err != nil { - zap.S().Warnf("failed to handle path input %s=%v: %v", key, pStr, err) - continue - } - } - inputs[key] = defaultValue - } - zap.S().Debugf("populated default value for input %s=%v", key, t) } -} -// OrderedBindings returns the bindings sorted by priority (lowest to highest). -// If two bindings have the same priority, their declaration order is preserved. -func (c *Construct) OrderedBindings() []*Binding { - if len(c.Bindings) == 0 { - return nil + // Add the construct name to the inputs + err = c.Inputs.SetProperty("Name", constructUrn.ResourceID) + if err != nil { + return nil, err } - sorted := append([]*Binding{}, c.Bindings...) - - sort.SliceStable(sorted, func(i, j int) bool { - if c.Bindings[i].Priority == c.Bindings[j].Priority { - return i < j - } - return c.Bindings[i].Priority < c.Bindings[j].Priority - }) - return sorted + err = ce.initializeInputs(c, i) + if err != nil { + return nil, err + } + return c, nil } diff --git a/pkg/k2/constructs/construct_evaluator.go b/pkg/k2/constructs/construct_evaluator.go index 9426fb727..1e43bf3a7 100644 --- a/pkg/k2/constructs/construct_evaluator.go +++ b/pkg/k2/constructs/construct_evaluator.go @@ -1,28 +1,25 @@ package constructs import ( - "bytes" "context" "errors" "fmt" + + "github.com/klothoplatform/klotho/pkg/async" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "reflect" - "regexp" - "strconv" + "slices" "strings" - "text/template" "github.com/dominikbraun/graph" - "github.com/klothoplatform/klotho/pkg/async" "github.com/klothoplatform/klotho/pkg/construct" "github.com/klothoplatform/klotho/pkg/engine" - "github.com/klothoplatform/klotho/pkg/engine/solution" stateconverter "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_converter" - statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template" "github.com/klothoplatform/klotho/pkg/k2/model" - "github.com/klothoplatform/klotho/pkg/k2/reflectutil" "github.com/klothoplatform/klotho/pkg/k2/stack" "github.com/klothoplatform/klotho/pkg/logging" - "go.uber.org/zap" ) type ConstructEvaluator struct { @@ -32,7 +29,7 @@ type ConstructEvaluator struct { stackStateManager *stack.StateManager stateConverter stateconverter.StateConverter - Constructs async.ConcurrentMap[model.URN, *Construct] + Constructs *async.ConcurrentMap[model.URN, *Construct] } func NewConstructEvaluator(sm *model.StateManager, ssm *stack.StateManager) (*ConstructEvaluator, error) { @@ -45,6 +42,7 @@ func NewConstructEvaluator(sm *model.StateManager, ssm *stack.StateManager) (*Co stateManager: sm, stackStateManager: ssm, stateConverter: stateConverter, + Constructs: &async.ConcurrentMap[model.URN, *Construct]{}, }, nil } @@ -53,7 +51,7 @@ func (ce *ConstructEvaluator) Evaluate(constructUrn model.URN, state model.State if err != nil { return engine.SolveRequest{}, fmt.Errorf("error evaluating construct %s: %w", constructUrn, err) } - err = ce.evaluateBindings(ci, state, ctx) + err = ce.evaluateBindings(ctx, ci) if err != nil { return engine.SolveRequest{}, fmt.Errorf("error evaluating bindings: %w", err) } @@ -75,387 +73,6 @@ func (ce *ConstructEvaluator) Evaluate(constructUrn model.URN, state model.State }, nil } -// Matches one or more interpolation groups in a string e.g., ${inputs:foo.bar}-baz-${resource:Boz} -var interpolationPattern = regexp.MustCompile(`\$\{([^:]+):([^}]+)}`) - -// Matches exactly one interpolation group e.g., ${inputs:foo.bar} -var isolatedInterpolationPattern = regexp.MustCompile(`^\$\{([^:]+):([^}]+)}$`) - -var spreadPattern = regexp.MustCompile(`\.\.\.}$`) - -/* - interpolateValue interpolates a value based on the context of the construct - rawValue is the value to interpolate. The format of a raw value is ${:} where prefix is the type of value to interpolate and key is the key to interpolate - - The key can be a path to a value in the context. For example, ${inputs:foo.bar} will interpolate the value of the key bar in the foo input. - The target of a dot-separated path can be a map or a struct. - The path can also include brackets to access an array. For example, ${inputs:foo[0].bar} will interpolate the value of the key bar in the first element of the foo input array. - A rawValue can contain a combination of interpolation expressions and literals. For example, "${inputs:foo.bar}-baz-${resource:Boz}" is a valid rawValue. -*/ -func (ce *ConstructEvaluator) interpolateValue(c InterpolationSource, rawValue any, ctx InterpolationContext) (any, error) { - if ref, ok := rawValue.(ResourceRef); ok { - switch ref.Type { - case ResourceRefTypeInterpolated: - return ce.interpolateValue(c, ref.ResourceKey, ctx) - case ResourceRefTypeTemplate: - ref.ConstructURN = ctx.Construct.URN - return ref, nil - default: - return rawValue, nil - } - } - - v := reflectutil.GetConcreteElement(reflect.ValueOf(rawValue)) - if !v.IsValid() { - return rawValue, nil - } - rawValue = v.Interface() - - switch v.Kind() { - case reflect.String: - return ce.interpolateString(c.GetPropertySource(), v.String(), ctx) - case reflect.Slice: - length := v.Len() - var interpolated []any - for i := 0; i < length; i++ { - // handle spread operator by injecting the spread value into the array at the current index - originalValue := reflectutil.GetConcreteValue(v.Index(i)) - if originalString, ok := originalValue.(string); ok && spreadPattern.MatchString(originalString) { - unspreadPath := originalString[:len(originalString)-4] + "}" - spreadValue, err := ce.interpolateValue(c, unspreadPath, ctx) - if err != nil { - return nil, err - } - - if spreadValue == nil { - continue - } - if reflect.TypeOf(spreadValue).Kind() != reflect.Slice { - return nil, errors.New("spread value must be a slice") - } - - for i := 0; i < reflect.ValueOf(spreadValue).Len(); i++ { - interpolated = append(interpolated, reflect.ValueOf(spreadValue).Index(i).Interface()) - } - continue - } - value, err := ce.interpolateValue(c, v.Index(i).Interface(), ctx) - if err != nil { - return nil, err - } - interpolated = append(interpolated, value) - } - return interpolated, nil - case reflect.Map: - keys := v.MapKeys() - interpolated := make(map[string]any) - for _, k := range keys { - key, err := ce.interpolateValue(c, k.Interface(), ctx) - if err != nil { - return nil, err - } - value, err := ce.interpolateValue(c, v.MapIndex(k).Interface(), ctx) - if err != nil { - return nil, err - } - interpolated[fmt.Sprint(key)] = value - } - return interpolated, nil - case reflect.Struct: - // Create a new instance of the struct - newStruct := reflect.New(v.Type()).Elem() - - // Interpolate each field - for i := 0; i < v.NumField(); i++ { - fieldName := v.Type().Field(i).Name - fieldValue, err := ce.interpolateValue(c, v.Field(i).Interface(), ctx) - if err != nil { - return nil, err - } - // Set the interpolated value to the field in the new struct - if fieldValue != nil { - newStruct.FieldByName(fieldName).Set(reflect.ValueOf(fieldValue)) - } - } - - // Return the new struct - return newStruct.Interface(), nil - default: - return rawValue, nil - } -} - -func (ce *ConstructEvaluator) interpolateString(ps *PropertySource, rawValue string, ctx InterpolationContext) (any, error) { - - // if the rawValue is an isolated interpolation expression, interpolate it and return the raw value - if isolatedInterpolationPattern.MatchString(rawValue) { - return ce.interpolateExpression(ps, rawValue, ctx) - } - - var err error - - // Replace each match in the rawValue (mixed expressions are always interpolated as strings) - interpolated := interpolationPattern.ReplaceAllStringFunc(rawValue, func(match string) string { - var val any - val, err = ce.interpolateExpression(ps, match, ctx) - return fmt.Sprint(val) - }) - if err != nil { - return nil, err - } - - return interpolated, nil -} - -func (ce *ConstructEvaluator) interpolateExpression(ps *PropertySource, match string, ctx InterpolationContext) (any, error) { - if ps == nil { - return nil, errors.New("property source is nil") - } - - // Split the match into prefix and key - parts := interpolationPattern.FindStringSubmatch(match) - prefix := parts[1] - key := parts[2] - - // Check if the prefix is allowed - allowed := false - for _, p := range ctx.AllowedKeys { - if p == InterpolationSourceKey(prefix) || p == FromInterpolation && strings.HasPrefix(prefix, "from.") || p == ToInterpolation && strings.HasPrefix(prefix, "to.") { - allowed = true - break - } - } - if !allowed { - return "", fmt.Errorf("interpolation prefix '%s' is not allowed in the current context", prefix) - } - - // Choose the correct root property from the source based on the prefix - var p any - ok := false - if prefix == "inputs" || prefix == "resources" || prefix == "edges" || prefix == "meta" || - strings.HasPrefix(prefix, "from.") || - strings.HasPrefix(prefix, "to.") { - p, ok = ps.GetProperty(prefix) - if !ok { - return nil, fmt.Errorf("could not get %s", prefix) - } - } else { - return nil, fmt.Errorf("invalid prefix: %s", prefix) - } - - prefixParts := strings.Split(prefix, ".") - - // associate any ResourceRefs with the URN of the property source they're being interpolated from - // if the prefix is "from" or "to", the URN of the property source is the "urn" field of that level in the property source - var refUrn model.URN - - if strings.HasSuffix(prefix, "resources") { - urnKey := "urn" - if prefixParts[0] == "from" || prefixParts[0] == "to" { - urnKey = fmt.Sprintf("%s.urn", prefixParts[0]) - } - psURN, ok := GetTypedProperty[model.URN](ps, urnKey) - if !ok { - psURN = ctx.Construct.URN - } - refUrn = psURN - } else { - propTrace, err := reflectutil.TracePath(reflect.ValueOf(p), key) - if err == nil { - refConstruct, ok := reflectutil.LastOfType[*Construct](propTrace) - if ok { - refUrn = refConstruct.URN - } - } - if refUrn.Equals(model.URN{}) { - refUrn = ctx.Construct.URN - } - } - - // return an IaC reference if the key matches the IaC reference pattern - if iacRefPattern.MatchString(key) { - return ResourceRef{ - ResourceKey: iacRefPattern.FindStringSubmatch(key)[1], - Property: iacRefPattern.FindStringSubmatch(key)[2], - Type: ResourceRefTypeIaC, - ConstructURN: refUrn, - }, nil - } - - // special cases for resources allowing for accessing the name of a resource directly instead of using .Id.Name - if prefix == "resources" || prefixParts[len(prefixParts)-1] == "resources" { - keyParts := strings.SplitN(key, ".", 2) - resourceKey := keyParts[0] - if len(keyParts) > 1 { - if path := keyParts[1]; path == "Name" { - return p.(map[string]*Resource)[resourceKey].Id.Name, nil - } - - } - } - - // Retrieve the value from the designated property source - value, err := getValueFromSource(p, key, false) - if err != nil { - zap.S().Debugf("could not get value from source: %s", err) - return nil, nil - } - - keyAndRef := strings.Split(key, "#") - var refProperty string - if len(keyAndRef) == 2 { - refProperty = keyAndRef[1] - } - - // If the value is a Resource, return a ResourceRef - if r, ok := value.(*Resource); ok { - return ResourceRef{ - ResourceKey: r.Id.String(), - Property: refProperty, - Type: ResourceRefTypeIaC, - ConstructURN: refUrn, - }, nil - } - - if r, ok := value.(ResourceRef); ok { - r.ConstructURN = refUrn - return r, nil - } - - // Replace the match with the value - return value, nil -} - -// iacRefPattern is a regular expression pattern that matches an IaC reference -// IaC references are in the format # - -var iacRefPattern = regexp.MustCompile(`^([a-zA-Z0-9_-]+)#([a-zA-Z0-9._-]+)$`) - -// getValueFromSource retrieves a value from a property source based on a key -// the flat parameter is used to determine if the key is a flat key or a path (mixed keys aren't supported at the moment) -// e.g (flat = true): key = "foo.bar" -> value = collection["foo."bar"], flat = false: key = "foo.bar" -> value = collection["foo"]["bar"] -func getValueFromSource(source any, key string, flat bool) (any, error) { - value := reflect.ValueOf(source) - - keyAndRef := strings.Split(key, "#") - if len(keyAndRef) > 2 { - return nil, fmt.Errorf("invalid engine reference property reference: %s", key) - } - - var refProperty string - if len(keyAndRef) == 2 { - refProperty = keyAndRef[1] - key = keyAndRef[0] - } - - // Split the key into parts if not flat - parts := []string{key} - if !flat { - parts = strings.Split(key, ".") - } - - var err error - var lastValidValue reflect.Value - lastValidIndex := -1 - - // Traverse the map/struct/array according to the parts - for i, part := range parts { - // Check if the part contains brackets - if strings.Contains(part, "[") && strings.HasSuffix(part, "]") { - // Split the part into the key and the index - keyAndIndex := strings.Split(strings.TrimRight(strings.TrimLeft(part, "["), "]"), "[") - key := keyAndIndex[0] - var index int - index, err = strconv.Atoi(keyAndIndex[1]) - if err != nil { - err = fmt.Errorf("could not parse index: %w", err) - break - } - - if r, ok := value.Interface().(*Resource); ok { - lastValidValue = reflect.ValueOf(r.Properties) - value, err = reflectutil.GetField(lastValidValue, part) - } else { - value, err = reflectutil.GetField(value, key) - } - if err != nil { - err = fmt.Errorf("could not get field: %w", err) - break - } - - value = reflectutil.GetConcreteElement(value) - kind := value.Kind() - - switch kind { - case reflect.Slice | reflect.Array: - if index >= value.Len() { - err = fmt.Errorf("index out of bounds: %d", index) - break - } - value = value.Index(index) - case reflect.Map: - value, err = reflectutil.GetField(value, key) - if err != nil { - err = fmt.Errorf("could not get field: %w", err) - break - } - default: - err = fmt.Errorf("invalid type: %s", kind) - } - } else { - // The part does not contain brackets - if value.Kind() == reflect.Map { - v := value.MapIndex(reflect.ValueOf(part)) - if v.IsValid() { - value = v - } else { - err = fmt.Errorf("could not get value for key: %s", key) - break - } - } else if r, ok := value.Interface().(*Resource); ok { - if len(parts) == 1 { - return ResourceRef{ - ResourceKey: part, - Property: refProperty, - Type: ResourceRefTypeTemplate, - }, nil - } else { - // if the parent is a resource, children are implicitly properties of the resource - lastValidValue = reflect.ValueOf(r.Properties) - value, err = reflectutil.GetField(lastValidValue, part) - if err != nil { - err = fmt.Errorf("could not get field: %w", err) - break - } - } - } else { - var rVal reflect.Value - rVal, err = reflectutil.GetField(value, part) - if err != nil { - err = fmt.Errorf("could not get field: %w", err) - break - } - value = rVal - } - } - if err != nil { - break - } - if i == len(parts)-1 { - return value.Interface(), nil - } - - lastValidValue = value - lastValidIndex = i - } - - if lastValidIndex > -1 { - return getValueFromSource(lastValidValue.Interface(), strings.Join(parts[lastValidIndex+1:], "."), true) - } - - return value.Interface(), err -} - /* evaluateInputRules evaluates the input rules of the construct @@ -474,15 +91,26 @@ input rules cannot use interpolation in the if condition in the example input() is a function that returns the value of the input with the given key */ -func (ce *ConstructEvaluator) evaluateInputRules(o InfraOwner, interpolationCtx InterpolationContext) error { +func (ce *ConstructEvaluator) evaluateInputRules(o InfraOwner) error { for _, rule := range o.GetInputRules() { - if err := ce.evaluateInputRule(o, rule, interpolationCtx); err != nil { + dv := &DynamicValueData{ + currentOwner: o, + } + + if err := ce.evaluateInputRule(dv, rule); err != nil { return fmt.Errorf("could not evaluate input rule: %w", err) } } return nil } +func (ce *ConstructEvaluator) evaluateInputRule(dv *DynamicValueData, rule template.InputRuleTemplate) error { + if rule.ForEach != "" { + return ce.evaluateForEachRule(dv, rule) + } + return ce.evaluateIfRule(dv, rule) +} + /* Evaluation Order: @@ -495,7 +123,6 @@ Evaluation Order: Binding Input Rules Binding Resources Binding Edges - Binding Conflict Resolvers */ func (ce *ConstructEvaluator) evaluateConstruct(constructUrn model.URN, state model.State, ctx context.Context) (*Construct, error) { @@ -504,30 +131,12 @@ func (ce *ConstructEvaluator) evaluateConstruct(constructUrn model.URN, state mo return nil, fmt.Errorf("could not get state state for construct: %s", constructUrn) } - inputs := make(map[string]any) - - templateId, err := ParseConstructTemplateId(constructUrn.Subtype) + inputs, err := ce.convertInputs(cState.Inputs) if err != nil { - return nil, fmt.Errorf("could not parse construct template id: %w", err) + return nil, fmt.Errorf("invalid inputs for construct: %w", err) } - ct, err := loadConstructTemplate(templateId) - if err != nil { - return nil, fmt.Errorf("could not load construct template: %w", err) - } - for k, v := range cState.Inputs { - inputTemplate, ok := ct.Inputs[k] - if !ok { - zap.S().Warnf("input %s not found in construct template", k) - } - v, err := ce.ResolveInput(k, v, inputTemplate) - if err != nil { - return nil, err - } - inputs[k] = v - } - - c, err := NewConstruct(constructUrn, inputs) + c, err := ce.newConstruct(constructUrn, inputs) if err != nil { return nil, fmt.Errorf("could not create construct: %w", err) } @@ -537,23 +146,23 @@ func (ce *ConstructEvaluator) evaluateConstruct(constructUrn model.URN, state mo return nil, fmt.Errorf("could not initialize bindings: %w", err) } - if err = ce.importResources(c, ctx); err != nil { + if err = ce.importResourcesFromInputs(c, ctx); err != nil { return nil, fmt.Errorf("could not import resources: %w", err) } - if err = ce.evaluateResources(c, NewInterpolationContext(c, ResourceInterpolationContext)); err != nil { + if err = ce.evaluateResources(c); err != nil { return nil, fmt.Errorf("could not evaluate resources: %w", err) } - if err = ce.evaluateEdges(c, NewInterpolationContext(c, EdgeInterpolationContext)); err != nil { + if err = ce.evaluateEdges(c); err != nil { return nil, fmt.Errorf("could not evaluate edges: %w", err) } - if err = ce.evaluateInputRules(c, NewInterpolationContext(c, InputRuleInterpolationContext)); err != nil { - return nil, fmt.Errorf("could not evaluate input rules: %w", err) + if err = ce.evaluateInputRules(c); err != nil { + return nil, err } - if err = ce.evaluateOutputs(c, NewInterpolationContext(c, OutputInterpolationContext)); err != nil { + if err = ce.evaluateOutputs(c); err != nil { return nil, fmt.Errorf("could not evaluate outputs: %w", err) } @@ -605,34 +214,19 @@ func (ce *ConstructEvaluator) initBindings(c *Construct, state model.State) erro return errors.New("to is a reserved input name") } - b, err := ce.newBinding(c.URN, d.From, d.To) + b, err := ce.newBinding(c.URN, d) if err != nil { return fmt.Errorf("could not create binding: %w", err) } - inputs := make(map[string]any) - for key, inputTemplate := range b.BindingTemplate.Inputs { - mVal, ok := d.Inputs[key] - if !ok { - continue - } - resolvedValue, err := ce.ResolveInput(key, mVal, inputTemplate) - if err != nil { - return fmt.Errorf("could not resolve input: %w", err) - } - inputs[key] = resolvedValue - } - populateDefaultInputValues(inputs, b.BindingTemplate.Inputs) - b.Inputs = inputs - c.Bindings = append(c.Bindings, b) } return nil } -func (ce *ConstructEvaluator) evaluateBindings(c *Construct, state model.State, ctx context.Context) error { +func (ce *ConstructEvaluator) evaluateBindings(ctx context.Context, c *Construct) error { for _, binding := range c.OrderedBindings() { - if err := ce.evaluateBinding(c, binding, state, ctx); err != nil { + if err := ce.evaluateBinding(ctx, binding); err != nil { return fmt.Errorf("could not evaluate binding: %w", err) } } @@ -640,75 +234,43 @@ func (ce *ConstructEvaluator) evaluateBindings(c *Construct, state model.State, return nil } -func getBinding(list []model.Binding, urn model.URN) (model.Binding, bool) { - for _, b := range list { - if b.URN.Equals(urn) { - return b, true - } +func (ce *ConstructEvaluator) evaluateBinding(ctx context.Context, b *Binding) error { + if b == nil { + return fmt.Errorf("binding is nil") } - return model.Binding{}, false -} + owner := b.Owner + if owner == nil { + return fmt.Errorf("binding owner is nil") -func (ce *ConstructEvaluator) evaluateBinding(owner *Construct, b *Binding, state model.State, ctx context.Context) error { - if owner == nil || b == nil { - return errors.New("construct or binding is nil") } - if b.BindingTemplate.From.Name == "" || b.BindingTemplate.To.Name == "" { return nil // assume that this binding does not modify the current construct } - var err error - - if b.From != nil { - cState, ok := state.Constructs[b.From.URN.ResourceID] - if !ok { - return fmt.Errorf("could not get state state for binding: (%s) %s -> %s", owner.URN, b.From.URN, b.To.URN) - } - bState, ok := getBinding(cState.Bindings, b.To.URN) - if !ok { - return fmt.Errorf("could not find binding by URN (%s)", b.To.URN) - } - - inputs := make(map[string]any) - for k, v := range bState.Inputs { - inputTemplate, ok := b.BindingTemplate.Inputs[k] - if !ok { - zap.S().Warnf("input %s not found in binding template", k) - } - v, err := ce.ResolveInput(k, v, inputTemplate) - if err != nil { - return err - } - inputs[k] = v - } - } - if err = ce.importResources(b, ctx); err != nil { + if err := ce.importResourcesFromInputs(b, ctx); err != nil { return fmt.Errorf("could not import resources: %w", err) } if b.From != nil && owner.URN.Equals(b.From.GetURN()) { // only import "to" resources if the binding is from the current construct - if err = ce.importBindingToResources(ctx, b); err != nil { + if err := ce.importBindingToResources(ctx, b); err != nil { return fmt.Errorf("could not import binding resources: %w", err) } } - interpolationCtx := NewInterpolationContext(owner, BindingInterpolationContext) - - if err = ce.evaluateResources(b, interpolationCtx); err != nil { + if err := ce.evaluateResources(b); err != nil { return fmt.Errorf("could not evaluate resources: %w", err) } - if err = ce.evaluateEdges(b, interpolationCtx); err != nil { + if err := ce.evaluateEdges(b); err != nil { return fmt.Errorf("could not evaluate edges: %w", err) } - if err = ce.evaluateInputRules(b, interpolationCtx); err != nil { + if err := ce.evaluateInputRules(b); err != nil { return fmt.Errorf("could not evaluate input rules: %w", err) } - if err = ce.evaluateOutputs(b, interpolationCtx); err != nil { + if err := ce.evaluateOutputs(b); err != nil { return fmt.Errorf("could not evaluate outputs: %w", err) } @@ -719,13 +281,18 @@ func (ce *ConstructEvaluator) evaluateBinding(owner *Construct, b *Binding, stat return nil } -func (ce *ConstructEvaluator) evaluateEdges(c InfraOwner, interpolationCtx InterpolationContext) error { - for _, edge := range c.GetTemplateEdges() { - e, err := ce.resolveEdge(c, edge, interpolationCtx) +func (ce *ConstructEvaluator) evaluateEdges(o InfraOwner) error { + dv := &DynamicValueData{ + currentOwner: o, + propertySource: o.GetPropertySource(), + } + + for _, edge := range o.GetTemplateEdges() { + e, err := ce.resolveEdge(dv, edge) if err != nil { return fmt.Errorf("could not resolve edge: %w", err) } - c.SetEdges(append(c.GetEdges(), e)) + o.SetEdges(append(o.GetEdges(), e)) } return nil } @@ -827,14 +394,19 @@ func edgeExists(edges []*Edge, newEdge *Edge) bool { return false } -func (ce *ConstructEvaluator) evaluateResources(o ResourceOwner, interpolationCtx InterpolationContext) error { +func (ce *ConstructEvaluator) evaluateResources(o InfraOwner) error { var err error - i := o.GetTemplateResourcesIterator() - i.ForEach(func(key string, resource ResourceTemplate) error { + dv := &DynamicValueData{ + currentOwner: o, + propertySource: o.GetPropertySource(), + } + + ri := o.GetTemplateResourcesIterator() + ri.ForEach(func(key string, resource template.ResourceTemplate) error { var r *Resource - r, err = ce.resolveResource(o, key, resource, interpolationCtx) + r, err = ce.resolveResource(dv, key, resource) if err != nil { - return stopIteration + return template.StopIteration } o.SetResource(key, r) return nil @@ -845,9 +417,9 @@ func (ce *ConstructEvaluator) evaluateResources(o ResourceOwner, interpolationCt return nil } -func GetPropertyFunc(ps *PropertySource) func(string) any { +func GetPropertyFunc(ps *template.PropertySource, path string) func(string) any { return func(key string) any { - i, ok := ps.GetProperty(fmt.Sprintf("inputs.%s", key)) + i, ok := ps.GetProperty(fmt.Sprintf("%s.%s", path, key)) if !ok { return nil } @@ -855,71 +427,166 @@ func GetPropertyFunc(ps *PropertySource) func(string) any { } } -func (ce *ConstructEvaluator) templateFunctions(ps *PropertySource) template.FuncMap { - funcs := template.FuncMap{} - funcs["inputs"] = GetPropertyFunc(ps) - return funcs +func (ce *ConstructEvaluator) evaluateForEachRule(dv *DynamicValueData, rule template.InputRuleTemplate) error { + parentPrefix := dv.resourceKeyPrefix + + ctx := DynamicValueContext{ + constructs: ce.Constructs, + } + var selected bool + if err := ctx.ExecuteUnmarshal(rule.ForEach, dv, &selected); err != nil { + return fmt.Errorf("result parsing failed: %w", err) + } + + if !selected { + return nil + } + + for _, hasNext := dv.currentSelection.Next(); hasNext; _, hasNext = dv.currentSelection.Next() { + prefix, err := ce.interpolateValue(dv, rule.Prefix) + if err != nil { + return fmt.Errorf("could not interpolate resource prefix: %w", err) + } + + dv := &DynamicValueData{ + currentOwner: dv.currentOwner, + currentSelection: dv.currentSelection, + propertySource: dv.propertySource, + } + + if prefix != "" && prefix != nil { + if parentPrefix != "" { + dv.resourceKeyPrefix = strings.Join([]string{parentPrefix, fmt.Sprintf("%s", prefix)}, ".") + } else { + dv.resourceKeyPrefix = fmt.Sprintf("%s", prefix) + } + } else { + dv.resourceKeyPrefix = parentPrefix + } + + ri := rule.Do.ResourcesIterator() + ri.ForEach(func(key string, resource template.ResourceTemplate) error { + if dv.resourceKeyPrefix != "" { + key = fmt.Sprintf("%s.%s", dv.resourceKeyPrefix, key) + } + + r, err := ce.resolveResource(dv, key, resource) + if err != nil { + return fmt.Errorf("could not resolve resource %s : %w", key, err) + } + dv.currentOwner.SetResource(key, r) + return nil + }) + + for _, edge := range rule.Do.Edges { + e, err := ce.resolveEdge(dv, edge) + if err != nil { + return fmt.Errorf("could not resolve edge: %w", err) + } + dv.currentOwner.SetEdges(append(dv.currentOwner.GetEdges(), e)) + } + + for _, rule := range rule.Do.Rules { + if err := ce.evaluateInputRule(dv, rule); err != nil { + return fmt.Errorf("could not evaluate input rule: %w", err) + } + + } + } + + return nil } -func (ce *ConstructEvaluator) evaluateInputRule(o InfraOwner, rule InputRuleTemplate, interpolationCtx InterpolationContext) error { - tmpl, err := template.New("input_rule").Option("missingkey=zero").Funcs(ce.templateFunctions(o.GetPropertySource())).Parse(rule.If) +func (ce *ConstructEvaluator) evaluateIfRule(dv *DynamicValueData, rule template.InputRuleTemplate) error { + parentPrefix := dv.resourceKeyPrefix + + prefix, err := ce.interpolateValue(dv, rule.Prefix) if err != nil { - return fmt.Errorf("template parsing failed for input rule: %s: %w", rule.If, err) - } - var rawResult bytes.Buffer - if err := tmpl.Execute(&rawResult, nil); err != nil { - return fmt.Errorf("template execution failed: %w", err) + return fmt.Errorf("could not interpolate resource prefix: %w", err) } - result := rawResult.String() - // If the input (eg 'field') is nil and the 'if' statement just uses '{{ inputs "field" }}', - // then the string result will be ''. Make sure we don't interpret that as a true condition. - executeThen := result != "" && result != "" && strings.ToLower(result) != "false" + dv = &DynamicValueData{ + currentOwner: dv.currentOwner, + currentSelection: dv.currentSelection, + propertySource: dv.propertySource, + } - var body ConditionalExpressionTemplate - if executeThen { - body = rule.Then + if prefix != "" && prefix != nil { + if parentPrefix != "" { + dv.resourceKeyPrefix = strings.Join([]string{parentPrefix, fmt.Sprintf("%s", prefix)}, ".") + } else { + dv.resourceKeyPrefix = fmt.Sprintf("%s", prefix) + } } else { - body = rule.Else + dv.resourceKeyPrefix = parentPrefix } - for key, resource := range body.Resources { - rp, err := ce.interpolateValue(o, resource, interpolationCtx) - if err != nil { - return fmt.Errorf("could not interpolate resource %s: %w", key, err) + ctx := DynamicValueContext{ + constructs: ce.Constructs, + } + + var boolResult bool + err = ctx.ExecuteUnmarshal(rule.If, dv, &boolResult) + + if err != nil { + return fmt.Errorf("result parsing failed: %w", err) + } + executeThen := boolResult + + var body template.ConditionalExpressionTemplate + if executeThen && rule.Then != nil { + body = *rule.Then + } else if rule.Else != nil { + body = *rule.Else + } + + ri := body.ResourcesIterator() + ri.ForEach(func(key string, resource template.ResourceTemplate) error { + if dv.resourceKeyPrefix != "" { + key = fmt.Sprintf("%s.%s", dv.resourceKeyPrefix, key) } - rt := rp.(ResourceTemplate) - r, err := ce.resolveResource(o, key, rt, interpolationCtx) + r, err := ce.resolveResource(dv, key, resource) if err != nil { return fmt.Errorf("could not resolve resource %s: %w", key, err) } - o.SetResource(key, r) - } + dv.currentOwner.SetResource(key, r) + return nil + }) for _, edge := range body.Edges { - e, err := ce.resolveEdge(o, edge, interpolationCtx) + e, err := ce.resolveEdge(dv, edge) if err != nil { return fmt.Errorf("could not resolve edge: %w", err) } - o.SetEdges(append(o.GetEdges(), e)) + dv.currentOwner.SetEdges(append(dv.currentOwner.GetEdges(), e)) } + + for _, rule := range body.Rules { + if err := ce.evaluateInputRule(dv, rule); err != nil { + return fmt.Errorf("could not evaluate input rule: %w", err) + } + } + return nil } -func (ce *ConstructEvaluator) resolveResource(o ResourceOwner, key string, rt ResourceTemplate, interpolationCtx InterpolationContext) (*Resource, error) { +func (ce *ConstructEvaluator) resolveResource(dv *DynamicValueData, key string, rt template.ResourceTemplate) (*Resource, error) { // update the resource if it already exists - resource, ok := o.GetResource(key) + if dv.currentOwner == nil { + return nil, fmt.Errorf("current owner is nil") + } + resource, ok := dv.currentOwner.GetResource(key) if !ok { resource = &Resource{Properties: map[string]any{}} } - tmpl, err := ce.interpolateValue(o, rt, interpolationCtx) + tmpl, err := ce.interpolateValue(dv, rt) if err != nil { return nil, fmt.Errorf("could not interpolate resource %s: %w", key, err) } - resTmpl := tmpl.(ResourceTemplate) + resTmpl := tmpl.(template.ResourceTemplate) typeParts := strings.Split(resTmpl.Type, ":") if len(typeParts) != 2 && resTmpl.Type != "" { return nil, fmt.Errorf("invalid resource type: %s", resTmpl.Type) @@ -968,41 +635,58 @@ func (ce *ConstructEvaluator) resolveResource(o ResourceOwner, key string, rt Re return resource, nil } -func (ce *ConstructEvaluator) resolveEdge(c InfraOwner, edge EdgeTemplate, interpolationCtx InterpolationContext) (*Edge, error) { - from, err := ce.interpolateValue(c, edge.From, interpolationCtx) +func (ce *ConstructEvaluator) resolveEdge(dv *DynamicValueData, edge template.EdgeTemplate) (*Edge, error) { + from, err := ce.interpolateValue(dv, edge.From) if err != nil { return nil, err } if from == nil { return nil, fmt.Errorf("from is nil") } - to, err := ce.interpolateValue(c, edge.To, interpolationCtx) + to, err := ce.interpolateValue(dv, edge.To) if err != nil { return nil, err } if to == nil { return nil, fmt.Errorf("to is nil") } - data, err := ce.interpolateValue(c, edge.Data, interpolationCtx) + data, err := ce.interpolateValue(dv, edge.Data) if err != nil { return nil, err } return &Edge{ - From: from.(ResourceRef), - To: to.(ResourceRef), + From: from.(template.ResourceRef), + To: to.(template.ResourceRef), Data: data.(construct.EdgeData), }, nil } -func (ce *ConstructEvaluator) evaluateOutputs(o InfraOwner, interpolationCtx InterpolationContext) error { - for key, output := range o.GetTemplateOutputs() { - output, err := ce.interpolateValue(o, output, interpolationCtx) +func (ce *ConstructEvaluator) evaluateOutputs(o InfraOwner) error { + // sort the keys of the outputs alphabetically to ensure deterministic ordering + sortKeys := func(m map[string]template.OutputTemplate) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + slices.Sort(keys) + return keys + } + + outputs := o.GetTemplateOutputs() + keys := sortKeys(outputs) + for _, key := range keys { + ot := outputs[key] + dv := &DynamicValueData{ + currentOwner: o, + propertySource: o.GetPropertySource(), + } + output, err := ce.interpolateValue(dv, ot) if err != nil { return fmt.Errorf("failed to interpolate value for output %s: %w", key, err) } - outputTemplate, ok := output.(OutputTemplate) + outputTemplate, ok := output.(template.OutputTemplate) if !ok { return fmt.Errorf("invalid output template for output %s", key) } @@ -1010,11 +694,11 @@ func (ce *ConstructEvaluator) evaluateOutputs(o InfraOwner, interpolationCtx Int var value any var ref construct.PropertyRef - r, ok := outputTemplate.Value.(ResourceRef) + r, ok := outputTemplate.Value.(template.ResourceRef) if !ok { value = outputTemplate.Value } else { - serializedRef, err := ce.serializeRef(o, r) + serializedRef, err := ce.marshalRef(o, r) if err != nil { return fmt.Errorf("failed to serialize ref for output %s: %w", key, err) } @@ -1047,200 +731,56 @@ func (ce *ConstructEvaluator) evaluateOutputs(o InfraOwner, interpolationCtx Int return nil } -var constructTypePattern = regexp.MustCompile(`^Construct\(([\w.-]+)\)$`) - -func (ce *ConstructEvaluator) importFrom(ctx context.Context, o InfraOwner, ic *Construct) error { - log := logging.GetLogger(ctx).Sugar() - - // TODO: DS - consider whether to include transitive resource imports - - initGraph := o.GetInitialGraph() - sol := ic.Solution - - stackState, hasState := ce.stackStateManager.ConstructStackState[ic.URN] - - // NOTE(gg): using topo sort to get all resources, order doesn't matter - resourceIds, err := construct.TopologicalSort(sol.DataflowGraph()) - if err != nil { - return fmt.Errorf("could not get resources from %s solution: %w", ic.URN, err) - } - resources := make(map[construct.ResourceId]*construct.Resource) - for _, rId := range resourceIds { - var liveStateRes *construct.Resource - if hasState { - if state, ok := stackState.Resources[rId]; ok { - liveStateRes, err = ce.stateConverter.ConvertResource(stateconverter.Resource{ - Urn: string(state.URN), - Type: string(state.Type), - Outputs: state.Outputs, - }) - if err != nil { - return fmt.Errorf("could not convert state for %s.%s: %w", ic.URN, rId, err) - } - log.Debugf("Imported %s from state", rId) - } - } - originalRes, err := sol.DataflowGraph().Vertex(rId) - if err != nil { - return fmt.Errorf("could not get resource %s.%s from solution: %w", ic.URN, rId, err) - } - - tmpl, err := sol.KnowledgeBase().GetResourceTemplate(rId) - if err != nil { - return fmt.Errorf("could not get resource template %s.%s: %w", ic.URN, rId, err) - } - - props := make(construct.Properties) - for k, v := range originalRes.Properties { - props[k] = v - } - hasImportId := false - // set a fake import id, otherwise index.ts will have things like - // Type.get("name", ) - for k, prop := range tmpl.Properties { - if prop.Details().Required && prop.Details().DeployTime { - if liveStateRes == nil { - if ce.DryRun > 0 { - props[k] = fmt.Sprintf("preview(id=%s)", rId) - hasImportId = true - continue - } else { - return fmt.Errorf("could not get live state resource %s (%s)", ic.URN, rId) - } - } - liveIdProp, err := liveStateRes.GetProperty(k) - if err != nil { - return fmt.Errorf("could not get property %s for %s: %w", k, rId, err) - } - props[k] = liveIdProp - hasImportId = true - } - } - if !hasImportId { - continue - } - - res := &construct.Resource{ - ID: originalRes.ID, - Properties: props, - Imported: true, - } - - log.Debugf("Imported %s from solution", rId) - - if err := initGraph.AddVertex(res); err != nil { - return fmt.Errorf("could not create imported resource %s from %s: %w", rId, ic.URN, err) - } - resources[rId] = res - } - err = filterImportProperties(resources) - if err != nil { - return fmt.Errorf("could not filter import properties for %s: %w", ic.URN, err) - } - - edges, err := sol.DataflowGraph().Edges() - if err != nil { - return fmt.Errorf("could not get edges from %s solution: %w", ic.URN, err) - } - for _, e := range edges { - err := initGraph.AddEdge(e.Source, e.Target, func(ep *graph.EdgeProperties) { - ep.Data = e.Properties.Data - }) - switch { - case err == nil: - log.Debugf("Imported edge %s -> %s from solution", e.Source, e.Target) - - case errors.Is(err, graph.ErrVertexNotFound): - log.Debugf("Skipping import edge %s -> %s from solution", e.Source, e.Target) - - default: - return fmt.Errorf("could not create imported edge %s -> %s from %s: %w", e.Source, e.Target, ic.URN, err) +func (ce *ConstructEvaluator) convertInputs(inputs map[string]model.Input) (construct.Properties, error) { + props := make(construct.Properties) + for k, v := range inputs { + if ce.DryRun == 0 && v.Status != model.InputStatusResolved { + return nil, fmt.Errorf("input %s is not resolved", k) } + props[k] = v.Value } + return props, nil +} - return nil +type HasInputs interface { + ForEachInput(f func(input property.Property) error) error + GetInputs() construct.Properties } -// filterImportProperties filters out any references to resources that were skipped from importing. -func filterImportProperties(resources map[construct.ResourceId]*construct.Resource) error { - var errs []error - clearProp := func(id construct.ResourceId, path construct.PropertyPath) { - if err := path.Remove(nil); err != nil { - errs = append(errs, - fmt.Errorf("error clearing %s: %w", construct.PropertyRef{Resource: id, Property: path.String()}, err), - ) - } - } - for id, r := range resources { - _ = r.WalkProperties(func(path construct.PropertyPath, _ error) error { - switch v := path.Get().(type) { - case construct.ResourceId: - if _, ok := resources[v]; !ok { - clearProp(id, path) +func (ce *ConstructEvaluator) initializeInputs(c HasInputs, i construct.Properties) error { + var inputErrors error + _ = c.ForEachInput(func(input property.Property) error { + v, err := i.GetProperty(input.Details().Path) + if err == nil { + if (v == nil || v == input.ZeroValue()) && input.Details().Required { + inputErrors = errors.Join(inputErrors, fmt.Errorf("input %s is required", input.Details().Path)) + return nil + } + if err = input.SetProperty(c.GetInputs(), v); err != nil { + inputErrors = errors.Join(inputErrors, err) + return nil + } + } else if errors.Is(err, construct.ErrPropertyDoesNotExist) { + if dv, err := input.GetDefaultValue(DynamicValueContext{}, nil); err == nil { + if dv == nil { + dv = input.ZeroValue() } - - case construct.PropertyRef: - if _, ok := resources[v.Resource]; !ok { - clearProp(id, path) + if (dv == nil || dv == input.ZeroValue()) && input.Details().Required { + inputErrors = errors.Join(inputErrors, fmt.Errorf("input %s is required", input.Details().Path)) + return nil + } + if dv == nil { + return nil // no default value (e.g., for collections or other types with type arguments, i.e., generics) + } + if err = input.SetProperty(c.GetInputs(), dv); err != nil { + inputErrors = errors.Join(inputErrors, err) + return nil } } - return nil - }) - } - return errors.Join(errs...) -} - -func (ce *ConstructEvaluator) importResources(o InfraOwner, ctx context.Context) error { - for iName, i := range o.GetTemplateInputs() { - // parse construct type from input type in the form of Construct(type) - // get the construct from the evaluator if it exists and is the correct type or return an error - // then go through the resources of the construct and add them to the imported resources of the current construct - // if the resource is not found, return an error - if i.Type == "Construct" { - return errors.New("input of type Construct must have a type specified in the form of Construct") - } - if !constructTypePattern.MatchString(i.Type) { - continue // skip the input if it is not a construct - } - - resolvedInput, ok := o.GetInput(iName) - if !ok { - return fmt.Errorf("could not find resolved input %s", iName) - } - - ic, ok := resolvedInput.(*Construct) - if !ok { - return fmt.Errorf("value %v of input %s is not a construct", iName, resolvedInput) - } - - if err := ce.importFrom(ctx, o, ic); err != nil { - return fmt.Errorf("could not import resources from %s: %w", ic.URN, err) + } else { + inputErrors = errors.Join(inputErrors, err) } - } - return nil -} - -func (ce *ConstructEvaluator) importBindingToResources(ctx context.Context, b *Binding) error { - return ce.importFrom(ctx, b, b.To) -} - -func (ce *ConstructEvaluator) RegisterOutputValues(urn model.URN, outputs map[string]any) { - if c, ok := ce.Constructs.Get(urn); ok { - c.Outputs = outputs - } -} - -func (ce *ConstructEvaluator) AddSolution(urn model.URN, sol solution.Solution) { - // panic is fine here if urn isn't in map - // will only happen in programmer error cases - c, _ := ce.Constructs.Get(urn) - c.Solution = sol -} - -func loadStateConverter() (stateconverter.StateConverter, error) { - templates, err := statetemplate.LoadStateTemplates("pulumi") - if err != nil { - return nil, err - } - return stateconverter.NewStateConverter("pulumi", templates), nil + return nil + }) + return inputErrors } diff --git a/pkg/k2/constructs/construct_evaluator_test.go b/pkg/k2/constructs/construct_evaluator_test.go index 83999a32f..7fbe0acf9 100644 --- a/pkg/k2/constructs/construct_evaluator_test.go +++ b/pkg/k2/constructs/construct_evaluator_test.go @@ -4,6 +4,8 @@ import ( "reflect" "testing" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/construct" "github.com/klothoplatform/klotho/pkg/k2/model" "github.com/stretchr/testify/assert" @@ -58,73 +60,93 @@ func TestInterpolateValue(t *testing.T) { tests := []struct { name string rawValue any - ctx InterpolationContext + data DynamicValueData expected any hasError bool }{ { name: "Simple string interpolation", rawValue: "${inputs:stringInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "Hello", }, { name: "Integer interpolation", rawValue: "${inputs:intInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: 42, }, { name: "Map value interpolation", rawValue: "${inputs:mapInput.key1}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "value1", }, { name: "Slice value interpolation", rawValue: "${inputs:sliceInput[1]}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "b", }, { name: "Resource property interpolation", rawValue: "${resources:testResource.prop1}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "value1", }, { name: "Struct field interpolation", rawValue: "${inputs:structInput.Field1}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "Hello", }, { name: "Struct interpolation", rawValue: "${inputs:structInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: simpleStruct, }, { name: "Mixed string interpolation", rawValue: "Prefix ${inputs:stringInput} Suffix", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "Prefix Hello Suffix", }, { name: "IaC reference interpolation", rawValue: "${resources:testResource#prop1}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), - expected: ResourceRef{ + data: DynamicValueData{ + currentOwner: mockConstruct, + }, + expected: template.ResourceRef{ ResourceKey: "testResource", Property: "prop1", - Type: ResourceRefTypeIaC, + Type: template.ResourceRefTypeIaC, ConstructURN: model.URN{ResourceID: "test-construct"}, }, }, { name: "Invalid interpolation prefix", rawValue: "${invalid:key}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, hasError: true, }, { @@ -136,7 +158,9 @@ func TestInterpolateValue(t *testing.T) { Field1: "${inputs:stringInput}", Field2: 42, }, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: struct { Field1 string Field2 int @@ -152,7 +176,9 @@ func TestInterpolateValue(t *testing.T) { "key2": 42, "${inputs:stringInput}": "value3", }, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: map[string]any{ "key1": "Hello", "key2": 42, @@ -162,94 +188,140 @@ func TestInterpolateValue(t *testing.T) { { name: "Slice interpolation", rawValue: []any{"${inputs:stringInput}", 42}, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: []any{"Hello", 42}, }, { name: "ResourceRef interpolation", - rawValue: ResourceRef{ResourceKey: "testResource", Property: "prop1", Type: ResourceRefTypeInterpolated}, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + rawValue: template.ResourceRef{ResourceKey: "testResource", Property: "prop1", Type: template.ResourceRefTypeInterpolated}, + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "testResource", }, { name: "ResourceRef template type", - rawValue: ResourceRef{ResourceKey: "testResource", Property: "prop1", Type: ResourceRefTypeTemplate}, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), - expected: ResourceRef{ResourceKey: "testResource", Property: "prop1", Type: ResourceRefTypeTemplate, ConstructURN: model.URN{ResourceID: "test-construct"}}, + rawValue: template.ResourceRef{ResourceKey: "testResource", Property: "prop1", Type: template.ResourceRefTypeTemplate}, + data: DynamicValueData{ + currentOwner: mockConstruct, + }, + expected: template.ResourceRef{ResourceKey: "testResource", Property: "prop1", Type: template.ResourceRefTypeTemplate, ConstructURN: model.URN{ResourceID: "test-construct"}}, }, { name: "Nested map interpolation", rawValue: map[string]any{"outer": map[string]any{"inner": "${inputs:stringInput}"}}, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: map[string]any{"outer": map[string]any{"inner": "Hello"}}, }, { name: "Nested slice interpolation", rawValue: []any{"${inputs:stringInput}", []any{"nested", "${inputs:intInput}"}}, - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: []any{"Hello", []any{"nested", 42}}, }, { name: "Non-existent input", rawValue: "${inputs:nonExistentInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: nil, }, { name: "Non-existent resource", rawValue: "${resources:nonExistentResource.prop}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, hasError: false, }, { name: "Non-existent resource property", rawValue: "${resources:testResource.nonExistentProp}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: nil, }, { name: "Invalid array index", rawValue: "${inputs:sliceInput[invalid]}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: nil, }, { name: "Out of bounds array index", rawValue: "${inputs:sliceInput[10]}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: nil, }, { name: "Multiple interpolations in a string", rawValue: "Start ${inputs:stringInput} middle ${inputs:intInput} end", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "Start Hello middle 42 end", }, { name: "Mixed string interpolation with slice interpolation", rawValue: "${inputs:stringInput} ${inputs:sliceInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "Hello [a b c]", }, { name: "Mixed string interpolation with map interpolation", rawValue: "${inputs:stringInput} ${inputs:mapInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, // the dynamic key has not been interpolated (that would occur in a separate step) expected: "Hello map[${inputs:stringInput}:value3 key1:value1 key2:2]", }, { name: "Mixed string interpolation with struct interpolation", rawValue: "${inputs:stringInput} ${inputs:stringerStructInput}", - ctx: NewInterpolationContext(mockConstruct, ResourceInterpolationContext), + data: DynamicValueData{ + currentOwner: mockConstruct, + }, expected: "Hello Hello", }, + { + name: "Go template interpolation", + rawValue: "Hello {{ .Inputs.stringInput }}", + data: DynamicValueData{currentOwner: mockConstruct}, + expected: "Hello Hello", + }, + { + name: "Go template evaluated before interpolation string", + rawValue: `Hello ${inputs:{{"stringInput"}}}`, + data: DynamicValueData{currentOwner: mockConstruct}, + expected: "Hello Hello", + }, + { + name: "Interleaved Go template and interpolation", + rawValue: "Hello {{ .Inputs.stringInput }} ${inputs:stringInput}", + data: DynamicValueData{currentOwner: mockConstruct}, + expected: "Hello Hello Hello", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ce := &ConstructEvaluator{} - result, err := ce.interpolateValue(mockConstruct, tt.rawValue, tt.ctx) + result, err := ce.interpolateValue(&tt.data, tt.rawValue) if tt.hasError { assert.Error(t, err) @@ -261,44 +333,6 @@ func TestInterpolateValue(t *testing.T) { } } -// Additional helper function to test the templateFunctions -func TestTemplateFunctions(t *testing.T) { - ce := &ConstructEvaluator{} - ps := &PropertySource{ - source: reflect.ValueOf(map[string]any{ - "inputs": map[string]any{ - "stringInput": "Hello", - "intInput": 42, - }, - }), - } - - funcs := ce.templateFunctions(ps) - inputsFunc := funcs["inputs"].(func(string) any) - - assert.Equal(t, "Hello", inputsFunc("stringInput")) - assert.Equal(t, 42, inputsFunc("intInput")) - assert.Nil(t, inputsFunc("nonExistentInput")) -} - -// Test for GetPropertyFunc -func TestGetPropertyFunc(t *testing.T) { - ps := &PropertySource{ - source: reflect.ValueOf(map[string]any{ - "inputs": map[string]any{ - "stringInput": "Hello", - "intInput": 42, - }, - }), - } - - getProperty := GetPropertyFunc(ps) - - assert.Equal(t, "Hello", getProperty("stringInput")) - assert.Equal(t, 42, getProperty("intInput")) - assert.Nil(t, getProperty("nonExistentInput")) -} - func TestGetValueFromSource(t *testing.T) { tests := []struct { name string @@ -445,7 +479,11 @@ func TestGetValueFromSource(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := getValueFromSource(tt.source, tt.key, tt.flat) + ce, err := NewConstructEvaluator(nil, nil) + if !assert.NoError(t, err) { + return + } + result, err := ce.getValueFromSource(tt.source, tt.key, tt.flat) if tt.err != "" { assert.Error(t, err) @@ -457,3 +495,53 @@ func TestGetValueFromSource(t *testing.T) { }) } } + +func TestNewConstruct(t *testing.T) { + tests := []struct { + name string + urn string + inputs map[string]any + expectedErr bool + expectedName string + }{ + { + name: "Valid inputs", + urn: "urn:accountid:project:dev::construct/klotho.aws.Bucket:my-bucket", + inputs: map[string]any{"someKey": "someValue"}, + expectedErr: false, + expectedName: "my-bucket", + }, + { + name: "Reserved Name key", + urn: "urn:accountid:project:dev::construct/klotho.aws.Bucket:my-bucket", + inputs: map[string]any{"Name": "invalid"}, + expectedErr: true, + }, + { + name: "Invalid URN type", + urn: "urn:accountid:project:dev::resource/klotho.aws.Bucket:invalidType", + inputs: map[string]any{"someKey": "someValue"}, + expectedErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + constructURN, err := model.ParseURN(tt.urn) + assert.NoError(t, err) + + ce, err := NewConstructEvaluator(nil, nil) + if !assert.NoError(t, err) { + return + } + + c, err := ce.newConstruct(*constructURN, tt.inputs) + if tt.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedName, c.Inputs["Name"]) + } + }) + } +} diff --git a/pkg/k2/constructs/construct_marshaller.go b/pkg/k2/constructs/construct_marshaller.go index a7b9f3144..6fd0f1cd3 100644 --- a/pkg/k2/constructs/construct_marshaller.go +++ b/pkg/k2/constructs/construct_marshaller.go @@ -5,10 +5,12 @@ import ( "reflect" "sort" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/reflectutil" + "github.com/klothoplatform/klotho/pkg/construct" "github.com/klothoplatform/klotho/pkg/engine/constraints" "github.com/klothoplatform/klotho/pkg/k2/model" - "github.com/klothoplatform/klotho/pkg/k2/reflectutil" ) type ( @@ -37,7 +39,7 @@ func (m *ConstructMarshaller) Marshal(constructURN model.URN) (constraints.Const for _, e := range c.Edges { edgeConstraints, err := m.marshalEdge(c, e) if err != nil { - return nil, fmt.Errorf("could not marshal edge: %w", err) + return nil, fmt.Errorf("could not marshall edge: %w", err) } cs = append(cs, edgeConstraints...) } @@ -45,7 +47,7 @@ func (m *ConstructMarshaller) Marshal(constructURN model.URN) (constraints.Const for _, o := range c.OutputDeclarations { outputConstraints, err := m.marshalOutput(o) if err != nil { - return nil, fmt.Errorf("could not marshal output: %w", err) + return nil, fmt.Errorf("could not marshall output: %w", err) } cs = append(cs, outputConstraints...) } @@ -65,7 +67,7 @@ func (m *ConstructMarshaller) marshalResource(o InfraOwner, r *Resource) (constr v, err := m.marshalRefs(o, v) if err != nil { - return nil, fmt.Errorf("could not marshall resource properties: %w", err) + return nil, fmt.Errorf("could not marshal resource properties: %w", err) } cs = append(cs, &constraints.ResourceConstraint{ Operator: "equals", @@ -82,7 +84,7 @@ func (m *ConstructMarshaller) marshalResource(o InfraOwner, r *Resource) (constr func (m *ConstructMarshaller) marshalEdge(o InfraOwner, e *Edge) (constraints.ConstraintList, error) { var from construct.ResourceId - ref, err := m.ConstructEvaluator.serializeRef(o, e.From) + ref, err := m.ConstructEvaluator.marshalRef(o, e.From) if err != nil { return nil, fmt.Errorf("could not serialize from resource id: %w", err) } @@ -96,7 +98,7 @@ func (m *ConstructMarshaller) marshalEdge(o InfraOwner, e *Edge) (constraints.Co } var to construct.ResourceId - ref, err = m.ConstructEvaluator.serializeRef(o, e.To) + ref, err = m.ConstructEvaluator.marshalRef(o, e.To) if err != nil { return nil, fmt.Errorf("could not serialize to resource id: %w", err) } @@ -110,7 +112,7 @@ func (m *ConstructMarshaller) marshalEdge(o InfraOwner, e *Edge) (constraints.Co } v, err := m.marshalRefs(o, e.Data) if err != nil { - return nil, fmt.Errorf("could not marshall resource properties: %w", err) + return nil, fmt.Errorf("could not marshal resource properties: %w", err) } return constraints.ConstraintList{&constraints.EdgeConstraint{ @@ -142,14 +144,18 @@ func (m *ConstructMarshaller) marshalOutput(o OutputDeclaration) (constraints.Co } func (m *ConstructMarshaller) marshalRefs(o InfraOwner, rawVal any) (any, error) { + if rawVal == nil { + return rawVal, nil + } + switch val := rawVal.(type) { - case *ResourceRef: + case *template.ResourceRef: if val == nil { - return val, nil + return rawVal, nil } - return m.ConstructEvaluator.serializeRef(o, *val) - case ResourceRef: - return m.ConstructEvaluator.serializeRef(o, val) + return m.ConstructEvaluator.marshalRef(o, *val) + case template.ResourceRef: + return m.ConstructEvaluator.marshalRef(o, val) case construct.ResourceId, construct.PropertyRef: return rawVal, nil } @@ -173,7 +179,7 @@ func (m *ConstructMarshaller) marshalRefs(o InfraOwner, rawVal any) (any, error) fieldValue := reflectutil.GetConcreteElement(field) if field.CanInterface() { - if _, ok := fieldValue.Interface().(ResourceRef); ok { + if _, ok := fieldValue.Interface().(template.ResourceRef); ok { // If we encounter a ResourceRef in a struct, we skip it // Since the result is not also a ResourceRef continue diff --git a/pkg/k2/constructs/construct_marshaller_test.go b/pkg/k2/constructs/construct_marshaller_test.go index 8b4414cc0..84f067e95 100644 --- a/pkg/k2/constructs/construct_marshaller_test.go +++ b/pkg/k2/constructs/construct_marshaller_test.go @@ -4,6 +4,9 @@ import ( "reflect" "testing" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/stretchr/testify/require" + "github.com/klothoplatform/klotho/pkg/async" "github.com/klothoplatform/klotho/pkg/construct" "github.com/klothoplatform/klotho/pkg/engine/constraints" @@ -15,15 +18,15 @@ import ( func TestConstructMarshaller(t *testing.T) { mockEvaluator := &ConstructEvaluator{ - Constructs: async.ConcurrentMap[model.URN, *Construct]{}, + Constructs: &async.ConcurrentMap[model.URN, *Construct]{}, } constructURN, _ := model.ParseURN("urn:accountid:project:dev::construct/klotho.aws.Bucket:my-bucket") mockConstruct := &Construct{ URN: *constructURN, Edges: []*Edge{ { - From: ResourceRef{ResourceKey: "aws:s3:test:bucket"}, - To: ResourceRef{ResourceKey: "aws:ec2:test:instance"}, + From: template.ResourceRef{ResourceKey: "aws:s3:test:bucket"}, + To: template.ResourceRef{ResourceKey: "aws:ec2:test:instance"}, Data: construct.EdgeData{}, }, }, @@ -159,13 +162,13 @@ func TestConstructMarshaller(t *testing.T) { }, Edges: []*Edge{ { - From: ResourceRef{ResourceKey: "aws:s3:test:bucket"}, - To: ResourceRef{ResourceKey: "aws:ec2:test:instance"}, + From: template.ResourceRef{ResourceKey: "aws:s3:test:bucket"}, + To: template.ResourceRef{ResourceKey: "aws:ec2:test:instance"}, Data: construct.EdgeData{}, }, { - From: ResourceRef{ResourceKey: "aws:ec2:test:instance"}, - To: ResourceRef{ResourceKey: "aws:lambda:test:function"}, + From: template.ResourceRef{ResourceKey: "aws:ec2:test:instance"}, + To: template.ResourceRef{ResourceKey: "aws:lambda:test:function"}, Data: construct.EdgeData{}, }, }, @@ -214,9 +217,9 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { o: testConstruct, rawVal: map[string]any{ "key1": "value1", - "key2": ResourceRef{ + "key2": template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, @@ -236,9 +239,9 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { args: args{ o: testConstruct, rawVal: []any{ - ResourceRef{ + template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, @@ -258,28 +261,27 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { o: testConstruct, rawVal: &struct { Field1 string - Field2 ResourceRef + Field2 template.ResourceRef }{ Field1: "value1", - Field2: ResourceRef{ + Field2: template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, }, want: &struct { Field1 string - Field2 ResourceRef + Field2 template.ResourceRef }{ Field1: "value1", - Field2: ResourceRef{ + Field2: template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, - wantErr: false, }, { name: "marshal nested struct with settable ResourceRef", @@ -288,16 +290,16 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { rawVal: &struct { Field1 string Nested struct { - Field2 ResourceRef + Field2 template.ResourceRef } }{ Field1: "value1", Nested: struct { - Field2 ResourceRef + Field2 template.ResourceRef }{ - Field2: ResourceRef{ + Field2: template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, @@ -306,29 +308,28 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { want: &struct { Field1 string Nested struct { - Field2 ResourceRef + Field2 template.ResourceRef } }{ Field1: "value1", Nested: struct { - Field2 ResourceRef + Field2 template.ResourceRef }{ - Field2: ResourceRef{ + Field2: template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, }, - wantErr: false, }, { name: "marshal interface with ResourceRef", args: args{ o: testConstruct, - rawVal: interface{}(ResourceRef{ + rawVal: interface{}(template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }), }, @@ -337,7 +338,6 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { Type: "s3_bucket", Name: "mybucket", }, - wantErr: false, }, { name: "marshal unsupported type", @@ -345,35 +345,31 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { o: testConstruct, rawVal: func() {}, // Using a function type to trigger the default case }, - want: func() {}, // Expecting the same unsupported type to be returned - wantErr: false, + want: func() {}, // Expecting the same unsupported type to be returned }, { name: "marshal nil pointer", args: args{ o: testConstruct, - rawVal: (*ResourceRef)(nil), + rawVal: (*template.ResourceRef)(nil), }, - want: (*ResourceRef)(nil), - wantErr: false, + want: (*template.ResourceRef)(nil), }, { name: "marshal nil map", args: args{ o: testConstruct, - rawVal: (map[string]ResourceRef)(nil), + rawVal: (map[string]template.ResourceRef)(nil), }, - want: (map[string]ResourceRef)(nil), - wantErr: false, + want: (map[string]template.ResourceRef)(nil), }, { name: "marshal nil slice", args: args{ o: testConstruct, - rawVal: ([]ResourceRef)(nil), + rawVal: ([]template.ResourceRef)(nil), }, - want: ([]ResourceRef)(nil), - wantErr: false, + want: ([]template.ResourceRef)(nil), }, { name: "marshal invalid value", @@ -381,8 +377,7 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { o: testConstruct, rawVal: nil, }, - want: nil, - wantErr: false, + want: nil, }, { name: "marshal zero value", @@ -390,8 +385,7 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { o: testConstruct, rawVal: struct{}{}, }, - want: struct{}{}, - wantErr: false, + want: struct{}{}, }, { name: "marshal pointer to struct with unsettable ResourceRef", @@ -399,37 +393,36 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { o: testConstruct, rawVal: &struct { Field1 string - field2 ResourceRef + field2 template.ResourceRef }{ Field1: "value1", - field2: ResourceRef{ + field2: template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, }, want: &struct { Field1 string - field2 ResourceRef + field2 template.ResourceRef }{ Field1: "value1", - field2: ResourceRef{ + field2: template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, - wantErr: false, }, { name: "marshal pointer to interface with ResourceRef", args: args{ o: testConstruct, rawVal: func() interface{} { - val := ResourceRef{ + val := template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, } return &val @@ -440,15 +433,14 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { Type: "s3_bucket", Name: "mybucket", }, - wantErr: false, }, { name: "marshal pointer to ResourceRef", args: args{ o: testConstruct, - rawVal: &ResourceRef{ + rawVal: &template.ResourceRef{ ResourceKey: "aws:s3_bucket:mybucket", - Type: ResourceRefTypeTemplate, + Type: template.ResourceRefTypeTemplate, ConstructURN: *constructURN, }, }, @@ -457,7 +449,6 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { Type: "s3_bucket", Name: "mybucket", }, - wantErr: false, }, { name: "marshal struct with unsettable int field", @@ -478,7 +469,6 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { Field1: "value1", field2: 100, }, - wantErr: false, }, { name: "marshal struct with settable int field", @@ -499,7 +489,6 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { Field1: "value1", Field2: 100, }, - wantErr: false, }, { name: "marshal struct with pointer to int field", @@ -520,7 +509,6 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { Field1: "value1", Field2: func() *int { v := 200; return &v }(), }, - wantErr: false, }, } @@ -535,9 +523,10 @@ func TestConstructMarshaller_marshalRefs(t *testing.T) { ConstructEvaluator: evaluator, } got, err := marshaller.marshalRefs(tt.args.o, tt.args.rawVal) - if (err != nil) != tt.wantErr { - t.Errorf("ConstructMarshaller.marshalRefs() error = %v, wantErr %v", err, tt.wantErr) - return + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) } if tt.name == "marshal unsupported type" { // Since we can't compare function types directly, we use reflection to check the type diff --git a/pkg/k2/constructs/construct_test.go b/pkg/k2/constructs/construct_test.go index a13235567..c0f092c88 100644 --- a/pkg/k2/constructs/construct_test.go +++ b/pkg/k2/constructs/construct_test.go @@ -1,59 +1,19 @@ package constructs import ( - "path/filepath" "testing" - "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + properties2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" + + "github.com/klothoplatform/klotho/pkg/construct" "github.com/stretchr/testify/assert" ) -func TestNewConstruct(t *testing.T) { - tests := []struct { - name string - urn string - inputs map[string]any - expectedErr bool - expectedName string - }{ - { - name: "Valid inputs", - urn: "urn:accountid:project:dev::construct/klotho.aws.Bucket:my-bucket", - inputs: map[string]any{"someKey": "someValue"}, - expectedErr: false, - expectedName: "my-bucket", - }, - { - name: "Reserved Name key", - urn: "urn:accountid:project:dev::construct/klotho.aws.Bucket:my-bucket", - inputs: map[string]any{"Name": "invalid"}, - expectedErr: true, - }, - { - name: "Invalid URN type", - urn: "urn:accountid:project:dev::resource/klotho.aws.Bucket:invalidType", - inputs: map[string]any{"someKey": "someValue"}, - expectedErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - constructURN, err := model.ParseURN(tt.urn) - assert.NoError(t, err) - - c, err := NewConstruct(*constructURN, tt.inputs) - if tt.expectedErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expectedName, c.Inputs["Name"]) - } - }) - } -} - func TestGetInput(t *testing.T) { c := &Construct{ Inputs: map[string]any{ @@ -63,29 +23,30 @@ func TestGetInput(t *testing.T) { } tests := []struct { - name string - key string - expected any - shouldFind bool + name string + key string + expected any + wantErr bool }{ { - name: "Existing key", - key: "key1", - expected: "value1", - shouldFind: true, + name: "Existing key", + key: "key1", + expected: "value1", }, { - name: "Non-existing key", - key: "nonexistent", - expected: nil, - shouldFind: false, + name: "Non-existing key", + key: "nonexistent", + expected: nil, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - value, found := c.GetInput(tt.key) - assert.Equal(t, tt.shouldFind, found) + value, err := c.GetInputValue(tt.key) + if tt.wantErr { + require.Error(t, err) + } assert.Equal(t, tt.expected, value) }) } @@ -125,17 +86,25 @@ func TestOrderedBindings(t *testing.T) { } func TestGetTemplateResourcesIterator(t *testing.T) { - mockResources := map[string]ResourceTemplate{ - "res1": {Type: "type1", Name: "name1", Namespace: "namespace1", Properties: map[string]any{"prop1": "value1"}}, - "res2": {Type: "type2", Name: "name2", Namespace: "namespace2", Properties: map[string]any{"prop2": "value2"}}, - } - mockTemplate := ConstructTemplate{ - Resources: mockResources, - resourceOrder: []string{"res1", "res2"}, - } + mockTemplate, err := parseConstructTemplate(` +resources: + res1: + type: type1 + name: name1 + namespace: namespace1 + properties: + prop1: value1 + res2: + type: type2 + name: name2 + namespace: namespace2 + properties: + prop2: value2 +`) + require.NoError(t, err) c := &Construct{ - ConstructTemplate: mockTemplate, + ConstructTemplate: *mockTemplate, } iter := c.GetTemplateResourcesIterator() @@ -153,78 +122,23 @@ func TestGetTemplateResourcesIterator(t *testing.T) { } } -func TestPopulateDefaultInputValues(t *testing.T) { - tests := []struct { - name string - inputs map[string]any - templates map[string]InputTemplate - expected map[string]any - }{ - { - name: "Populate default path value", - inputs: map[string]any{}, - templates: map[string]InputTemplate{ - "pathInput": { - Default: "default/path", - Type: "path", - }, - }, - expected: map[string]any{ - "pathInput": getAbsolutePath(t, "default/path"), - }, - }, - { - name: "Populate non-path default value", - inputs: map[string]any{}, - templates: map[string]InputTemplate{ - "simpleInput": { - Default: "default-value", - }, - }, - expected: map[string]any{ - "simpleInput": "default-value", - }, - }, - { - name: "Existing value should not be overwritten", - inputs: map[string]any{ - "simpleInput": "existing-value", - }, - templates: map[string]InputTemplate{ - "simpleInput": { - Default: "default-value", - }, - }, - expected: map[string]any{ - "simpleInput": "existing-value", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - populateDefaultInputValues(tt.inputs, tt.templates) - for key, expectedValue := range tt.expected { - assert.Equal(t, expectedValue, tt.inputs[key]) - } - }) +func parseConstructTemplate(yamlStr string) (*template.ConstructTemplate, error) { + mockTemplate := &template.ConstructTemplate{} + err := yaml.Unmarshal([]byte(yamlStr), mockTemplate) + if err != nil { + return nil, err } -} - -func getAbsolutePath(t *testing.T, path string) string { - absPath, err := filepath.Abs(path) - assert.NoError(t, err) - return absPath + return mockTemplate, nil } func TestConstructMethods(t *testing.T) { // Common setup - edgeTemplates := []EdgeTemplate{{}, {}} - inputRules := []InputRuleTemplate{{}, {}} - outputTemplates := map[string]OutputTemplate{"output1": {}, "output2": {}} - inputTemplates := map[string]InputTemplate{"input1": {}, "input2": {}} + edgeTemplates := []template.EdgeTemplate{{}, {}} + inputRules := []template.InputRuleTemplate{{}, {}} + outputTemplates := map[string]template.OutputTemplate{"output1": {}, "output2": {}} + inputTemplates := template.NewProperties(map[string]property.Property{"input1": &properties2.StringProperty{}, "input2": &properties2.StringProperty{}}) initialGraph := construct.NewGraph() - resourceTemplates := map[string]ResourceTemplate{"resource1": {}, "resource2": {}} + resourceTemplates := map[string]template.ResourceTemplate{"resource1": {}, "resource2": {}} resources := map[string]*Resource{"resource1": {}, "resource2": {}} inputs := map[string]any{"input1": "value1"} meta := map[string]any{"meta1": "value1"} @@ -238,28 +152,27 @@ func TestConstructMethods(t *testing.T) { assert.NoError(t, err) edges = append(edges, &Edge{ - From: ResourceRef{ + From: template.ResourceRef{ ConstructURN: *fromURN, ResourceKey: "from-resource", Property: "from-property", - Type: ResourceRefTypeIaC, + Type: template.ResourceRefTypeIaC, }, - To: ResourceRef{ + To: template.ResourceRef{ ConstructURN: *toURN, ResourceKey: "to-resource", Property: "to-property", - Type: ResourceRefTypeIaC, + Type: template.ResourceRefTypeIaC, }, Data: construct.EdgeData{}, }) - mockTemplate := ConstructTemplate{ - Edges: edgeTemplates, - InputRules: inputRules, - Outputs: outputTemplates, - Inputs: inputTemplates, - Resources: resourceTemplates, - resourceOrder: []string{"resource1", "resource2"}, + mockTemplate := template.ConstructTemplate{ + Edges: edgeTemplates, + InputRules: inputRules, + Outputs: outputTemplates, + Inputs: inputTemplates, + Resources: resourceTemplates, } urn, err := model.ParseURN("urn:accountid:project:dev::construct/klotho.aws.Bucket:my-bucket") @@ -319,11 +232,6 @@ func TestConstructMethods(t *testing.T) { assert.Equal(t, initialGraph, graph) }) - t.Run("GetTemplateInputs", func(t *testing.T) { - inputs := c.GetTemplateInputs() - assert.Len(t, inputs, len(inputTemplates)) - }) - t.Run("GetURN", func(t *testing.T) { retrievedURN := c.GetURN() assert.Equal(t, *urn, retrievedURN) @@ -343,11 +251,11 @@ func TestConstructMethods(t *testing.T) { refURN, err := model.ParseURN("urn:accountid:project:dev::construct/klotho.aws.Bucket:resource") assert.NoError(t, err) - ref := ResourceRef{ + ref := template.ResourceRef{ ConstructURN: *refURN, ResourceKey: "resource-key", Property: "property", - Type: ResourceRefTypeIaC, + Type: template.ResourceRefTypeIaC, } expected := "resource-key#property" assert.Equal(t, expected, ref.String()) diff --git a/pkg/k2/constructs/constructs.go b/pkg/k2/constructs/constructs.go index c245575fc..439c299da 100644 --- a/pkg/k2/constructs/constructs.go +++ b/pkg/k2/constructs/constructs.go @@ -1,12 +1,10 @@ package constructs import ( - "reflect" - "text/template" - "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" "github.com/klothoplatform/klotho/pkg/k2/model" - "github.com/klothoplatform/klotho/pkg/k2/reflectutil" ) type ( @@ -14,57 +12,33 @@ type ( GetResource(resourceId string) (resource *Resource, ok bool) SetResource(resourceId string, resource *Resource) GetResources() map[string]*Resource - GetTemplateResourcesIterator() Iterator[string, ResourceTemplate] - InterpolationSource + GetTemplateResourcesIterator() template.Iterator[string, template.ResourceTemplate] + template.InterpolationSource } EdgeOwner interface { - GetTemplateEdges() []EdgeTemplate + GetTemplateEdges() []template.EdgeTemplate GetEdges() []*Edge SetEdges(edges []*Edge) - InterpolationSource + template.InterpolationSource } InfraOwner interface { GetURN() model.URN - GetInputRules() []InputRuleTemplate + GetInputRules() []template.InputRuleTemplate ResourceOwner EdgeOwner - GetTemplateOutputs() map[string]OutputTemplate + GetTemplateOutputs() map[string]template.OutputTemplate DeclareOutput(key string, declaration OutputDeclaration) - GetTemplateInputs() map[string]InputTemplate - GetInput(name string) (value any, ok bool) + ForEachInput(f func(input property.Property) error) error + GetInputValue(name string) (value any, err error) GetInitialGraph() construct.Graph - } - - InterpolationSource interface { - GetPropertySource() *PropertySource - } - - PropertySource struct { - source reflect.Value - } - - TemplateFuncSupplier interface { - GetTemplateFuncs() template.FuncMap + GetConstruct() *Construct } ) -func NewPropertySource(source any) *PropertySource { - return &PropertySource{ - source: reflect.ValueOf(source), - } -} - -func (p *PropertySource) GetProperty(key string) (value any, ok bool) { - v, err := reflectutil.GetField(p.source, key) - if err != nil || !v.IsValid() { - return nil, false - } - return v.Interface(), true -} - -func (ce *ConstructEvaluator) serializeRef(owner InfraOwner, ref ResourceRef) (any, error) { +// marshalRef marshals a resource reference into a [construct.ResourceId] or [construct.PropertyRef] +func (ce *ConstructEvaluator) marshalRef(owner InfraOwner, ref template.ResourceRef) (any, error) { var resourceId construct.ResourceId r, ok := owner.GetResource(ref.ResourceKey) if ok { @@ -85,14 +59,3 @@ func (ce *ConstructEvaluator) serializeRef(owner InfraOwner, ref ResourceRef) (a return resourceId, nil } - -func GetTypedProperty[T any](source *PropertySource, key string) (T, bool) { - var typedField T - v, ok := source.GetProperty(key) - - if !ok { - return typedField, false - } - - return reflectutil.GetTypedValue[T](v) -} diff --git a/pkg/k2/constructs/dynamic_value.go b/pkg/k2/constructs/dynamic_value.go new file mode 100644 index 000000000..132bf6b2b --- /dev/null +++ b/pkg/k2/constructs/dynamic_value.go @@ -0,0 +1,322 @@ +package constructs + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "slices" + "strings" + "text/template" + + "github.com/klothoplatform/klotho/pkg/async" + "github.com/klothoplatform/klotho/pkg/construct" + template2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties" + "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/klothoplatform/klotho/pkg/templateutils" + "go.uber.org/zap" +) + +type DynamicValueContext struct { + constructs *async.ConcurrentMap[model.URN, *Construct] +} + +type DynamicValueData struct { + currentOwner InfraOwner + currentSelection DynamicValueSelection + propertySource *template2.PropertySource + resourceKeyPrefix string +} + +func (ctx DynamicValueContext) TemplateFunctions() template.FuncMap { + return templateutils.WithCommonFuncs(template.FuncMap{ + "fieldRef": ctx.FieldRef, + "pathAncestor": ctx.PathAncestor, + "pathAncestorExists": ctx.PathAncestorExists, + "toJSON": ctx.toJson, + }) +} + +func (ctx DynamicValueContext) Parse(tmpl string) (*template.Template, error) { + t, err := template.New("config").Funcs(ctx.TemplateFunctions()).Parse(tmpl) + return t, err +} + +func (ctx DynamicValueContext) ExecuteUnmarshal(tmpl string, data any, value any) error { + t, err := ctx.Parse(tmpl) + if err != nil { + return err + } + return ctx.ExecuteTemplateUnmarshal(t, data, value) +} + +func (ctx DynamicValueContext) Unmarshal(data *bytes.Buffer, v any) error { + return properties.UnmarshalAny(data, v) +} + +// ExecuteTemplateUnmarshal executes the template tmpl using data as input and unmarshals the value into v +func (ctx DynamicValueContext) ExecuteTemplateUnmarshal( + t *template.Template, + data any, + v any, +) error { + buf := new(bytes.Buffer) + if err := t.Execute(buf, data); err != nil { + return err + } + + if err := ctx.Unmarshal(buf, v); err != nil { + return fmt.Errorf("cannot decode template result '%s' into %T", buf, v) + } + + return nil +} + +// Self returns the owner of this dynamic value +func (data *DynamicValueData) Self() any { + return data.currentOwner +} + +// Selected returns the current selection in the dynamic value data +func (data *DynamicValueData) Selected() DynamicValueSelection { + return data.currentSelection +} + +func (data *DynamicValueData) Select(path string) bool { + var ps *template2.PropertySource + if data.currentSelection.Value != nil { + ps = template2.NewPropertySource(data.currentSelection.Value) + } else { + ps = data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + } + + if v, ok := ps.GetProperty(path); ok { + s := SelectItem(v) + data.currentSelection = s + return true + } + return false +} + +// Inputs returns the inputs of the current owner +func (data *DynamicValueData) Inputs() any { + ps := data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + val, _ := ps.GetProperty("inputs") + return val +} + +// Resources returns the resources of the current owner +func (data *DynamicValueData) Resources() any { + ps := data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + val, _ := ps.GetProperty("resources") + return val +} + +// Edges returns the edges of the current owner +func (data *DynamicValueData) Edges() any { + ps := data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + val, _ := ps.GetProperty("edges") + return val +} + +// Meta returns the metadata of the current owner +func (data *DynamicValueData) Meta() any { + ps := data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + val, _ := ps.GetProperty("meta") + return val +} + +func (data *DynamicValueData) Prefix() string { + return data.resourceKeyPrefix +} + +// From returns the 'from' construct if the current owner is a binding +func (data *DynamicValueData) From() any { + ps := data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + val, _ := ps.GetProperty("from") + return val +} + +// To returns the 'to' construct if the current owner is a binding +func (data *DynamicValueData) To() any { + ps := data.propertySource + if ps == nil { + ps = data.currentOwner.GetPropertySource() + } + val, _ := ps.GetProperty("to") + return val +} + +// Log is primarily used for debugging templates and only be used in development to log messages to the console +func (data *DynamicValueData) Log(level string, message string, args ...interface{}) string { + l := zap.L() + + ownerType := reflect.TypeOf(data.currentOwner).Kind().String() + ownerString := "unknown" + + l = l.With(zap.String(ownerType, ownerString)) + + switch strings.ToLower(level) { + case "debug": + l.Sugar().Debugf(message, args...) + case "info": + l.Sugar().Infof(message, args...) + case "warn": + l.Sugar().Warnf(message, args...) + case "error": + l.Sugar().Errorf(message, args...) + default: + l.Sugar().Warnf(message, args...) + } + return "" +} + +// toJson is used to return complex values that do not have TextUnmarshaler implemented +func (ctx DynamicValueContext) toJson(value any) (string, error) { + j, err := json.Marshal(value) + if err != nil { + return "", err + } + return string(j), nil +} + +func (ctx DynamicValueContext) PathAncestor(path construct.PropertyPath, depth int) (string, error) { + if depth < 0 { + return "", fmt.Errorf("depth must be >= 0") + } + if depth == 0 { + return path.String(), nil + } + if len(path) <= depth { + return "", fmt.Errorf("depth %d is greater than path length %d", depth, len(path)) + } + return path[:len(path)-depth].String(), nil +} + +func (ctx DynamicValueContext) PathAncestorExists(path construct.PropertyPath, depth int) bool { + return len(path) > depth +} + +// FieldRef returns a reference to `field` on `resource` (as a PropertyRef) +func (ctx DynamicValueContext) FieldRef(field string, resource any) (construct.PropertyRef, error) { + resId, err := TemplateArgToRID(resource) + if err != nil { + return construct.PropertyRef{}, err + } + + return construct.PropertyRef{ + Resource: resId, + Property: field, + }, nil +} + +func TemplateArgToRID(arg any) (construct.ResourceId, error) { + switch arg := arg.(type) { + case construct.ResourceId: + return arg, nil + + case construct.Resource: + return arg.ID, nil + + case string: + var resId construct.ResourceId + err := resId.UnmarshalText([]byte(arg)) + return resId, err + } + + return construct.ResourceId{}, fmt.Errorf("invalid argument type %T", arg) +} + +type DynamicValueSelection struct { + Source any + mapKeys []reflect.Value + next int + Key string + Value any + Index int +} + +func SelectItem(src any) DynamicValueSelection { + srcValue := reflect.ValueOf(src) + switch srcValue.Kind() { + case reflect.Map: + if !srcValue.IsValid() || srcValue.Len() == 0 { + return DynamicValueSelection{ + Source: src, + } + } + keys := srcValue.MapKeys() + slices.SortStableFunc(keys, func(i, j reflect.Value) int { + return strings.Compare(stringValue(i.Interface()), stringValue(j.Interface())) + }) + if len(keys) == 0 { + return DynamicValueSelection{ + Source: src, + } + } + return DynamicValueSelection{ + Source: src, + mapKeys: keys, + } + default: + return DynamicValueSelection{ + Source: src, + } + } +} + +// Next returns the next value in the selection and whether there are more values +// If the selection is a map, the key is also returned. +// If the selection is a slice, the index is returned instead. +// If there are no more values, the second return value is false. +// +// This function is intended to be used by an orchestration layer across multiple go templates +// and is unavailable inside the templates themselves +func (s *DynamicValueSelection) Next() (any, bool) { + srcValue := reflect.ValueOf(s.Source) + + if !srcValue.IsValid() { + return nil, false + } + + if len(s.mapKeys) > 0 { + if s.next >= len(s.mapKeys) { + return nil, false + } + key := s.mapKeys[s.next] + value := srcValue.MapIndex(key).Interface() + s.Value = value + s.next++ + return value, true + } + if s.Index >= srcValue.Len() { + return nil, false + } + value := srcValue.Index(s.Index).Interface() + s.Value = value + s.Index++ + return value, true +} + +func stringValue(v any) string { + return fmt.Sprintf("%v", v) +} diff --git a/pkg/k2/constructs/import_resources.go b/pkg/k2/constructs/import_resources.go new file mode 100644 index 000000000..d62f4a9e2 --- /dev/null +++ b/pkg/k2/constructs/import_resources.go @@ -0,0 +1,215 @@ +package constructs + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/dominikbraun/graph" + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/engine/solution" + stateconverter "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_converter" + statetemplate "github.com/klothoplatform/klotho/pkg/infra/state_reader/state_template" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/klothoplatform/klotho/pkg/logging" +) + +func (ce *ConstructEvaluator) importFrom(ctx context.Context, o InfraOwner, ic *Construct) error { + log := logging.GetLogger(ctx).Sugar() + initGraph := o.GetInitialGraph() + sol := ic.Solution + stackState, hasState := ce.stackStateManager.ConstructStackState[ic.URN] + + // NOTE(gg): using topo sort to get all resources, order doesn't matter + resourceIds, err := construct.TopologicalSort(sol.DataflowGraph()) + if err != nil { + return fmt.Errorf("could not get resources from %s solution: %w", ic.URN, err) + } + resources := make(map[construct.ResourceId]*construct.Resource) + for _, rId := range resourceIds { + var liveStateRes *construct.Resource + if hasState { + if state, ok := stackState.Resources[rId]; ok { + liveStateRes, err = ce.stateConverter.ConvertResource(stateconverter.Resource{ + Urn: string(state.URN), + Type: string(state.Type), + Outputs: state.Outputs, + }) + if err != nil { + return fmt.Errorf("could not convert state for %s.%s: %w", ic.URN, rId, err) + } + log.Debugf("Imported %s from state", rId) + } + } + originalRes, err := sol.DataflowGraph().Vertex(rId) + if err != nil { + return fmt.Errorf("could not get resource %s.%s from solution: %w", ic.URN, rId, err) + } + + tmpl, err := sol.KnowledgeBase().GetResourceTemplate(rId) + if err != nil { + return fmt.Errorf("could not get resource template %s.%s: %w", ic.URN, rId, err) + } + + props := make(construct.Properties) + for k, v := range originalRes.Properties { + props[k] = v + } + hasImportId := false + // set a fake import id, otherwise index.ts will have things like + // Type.get("name", ) + for k, prop := range tmpl.Properties { + if prop.Details().Required && prop.Details().DeployTime { + if liveStateRes == nil { + if ce.DryRun > 0 { + props[k] = fmt.Sprintf("preview(id=%s)", rId) + hasImportId = true + continue + } else { + return fmt.Errorf("could not get live state resource %s (%s)", ic.URN, rId) + } + } + liveIdProp, err := liveStateRes.GetProperty(k) + if err != nil { + return fmt.Errorf("could not get property %s for %s: %w", k, rId, err) + } + props[k] = liveIdProp + hasImportId = true + } + } + if !hasImportId { + continue + } + + res := &construct.Resource{ + ID: originalRes.ID, + Properties: props, + Imported: true, + } + + log.Debugf("Imported %s from solution", rId) + + if err := initGraph.AddVertex(res); err != nil { + return fmt.Errorf("could not create imported resource %s from %s: %w", rId, ic.URN, err) + } + resources[rId] = res + } + err = filterImportProperties(resources) + if err != nil { + return fmt.Errorf("could not filter import properties for %s: %w", ic.URN, err) + } + + edges, err := sol.DataflowGraph().Edges() + if err != nil { + return fmt.Errorf("could not get edges from %s solution: %w", ic.URN, err) + } + for _, e := range edges { + err := initGraph.AddEdge(e.Source, e.Target, func(ep *graph.EdgeProperties) { + ep.Data = e.Properties.Data + }) + switch { + case err == nil: + log.Debugf("Imported edge %s -> %s from solution", e.Source, e.Target) + + case errors.Is(err, graph.ErrVertexNotFound): + log.Debugf("Skipping import edge %s -> %s from solution", e.Source, e.Target) + + default: + return fmt.Errorf("could not create imported edge %s -> %s from %s: %w", e.Source, e.Target, ic.URN, err) + } + } + + return nil +} + +// filterImportProperties filters out any references to resources that were skipped from importing. +func filterImportProperties(resources map[construct.ResourceId]*construct.Resource) error { + var errs []error + clearProp := func(id construct.ResourceId, path construct.PropertyPath) { + if err := path.Remove(nil); err != nil { + errs = append(errs, + fmt.Errorf("error clearing %s: %w", construct.PropertyRef{Resource: id, Property: path.String()}, err), + ) + } + } + for id, r := range resources { + _ = r.WalkProperties(func(path construct.PropertyPath, _ error) error { + v, ok := path.Get() + if !ok { + return nil + } + switch v := v.(type) { + case construct.ResourceId: + if _, ok := resources[v]; !ok { + clearProp(id, path) + } + + case construct.PropertyRef: + if _, ok := resources[v.Resource]; !ok { + clearProp(id, path) + } + } + return nil + }) + } + return errors.Join(errs...) +} + +// importResourcesFromInputs imports resources from the construct-type inputs of the provided [InfraOwner], o. +// It returns an error if the input value is does not represent a valid construct +// or if importing the resources fails. +func (ce *ConstructEvaluator) importResourcesFromInputs(o InfraOwner, ctx context.Context) error { + return o.ForEachInput(func(i property.Property) error { + // if the input is a construct, import the resources from it + if !strings.HasPrefix(i.Type(), "construct") { + return nil + } + + resolvedInput, err := o.GetInputValue(i.Details().Path) + if err != nil { + return fmt.Errorf("could not get input %s: %w", i.Details().Path, err) + } + + cURN, ok := resolvedInput.(model.URN) + if !ok || !cURN.IsResource() || cURN.Type != "construct" { + return fmt.Errorf("input %s is not a construct URN", i.Details().Path) + } + + c, ok := ce.Constructs.Get(cURN) + if !ok { + return fmt.Errorf("could not find construct %s", cURN) + } + + if err := ce.importFrom(ctx, o, c); err != nil { + return fmt.Errorf("could not import resources from %s: %w", cURN, err) + } + return nil + }) +} + +func (ce *ConstructEvaluator) importBindingToResources(ctx context.Context, b *Binding) error { + return ce.importFrom(ctx, b, b.To) +} + +func (ce *ConstructEvaluator) RegisterOutputValues(urn model.URN, outputs map[string]any) { + if c, ok := ce.Constructs.Get(urn); ok { + c.Outputs = outputs + } +} + +func (ce *ConstructEvaluator) AddSolution(urn model.URN, sol solution.Solution) { + // panic is fine here if urn isn't in map + // will only happen in programmer error cases + c, _ := ce.Constructs.Get(urn) + c.Solution = sol +} + +func loadStateConverter() (stateconverter.StateConverter, error) { + templates, err := statetemplate.LoadStateTemplates("pulumi") + if err != nil { + return nil, err + } + return stateconverter.NewStateConverter("pulumi", templates), nil +} diff --git a/pkg/k2/constructs/input_resolver.go b/pkg/k2/constructs/input_resolver.go deleted file mode 100644 index 39c4ffc91..000000000 --- a/pkg/k2/constructs/input_resolver.go +++ /dev/null @@ -1,113 +0,0 @@ -package constructs - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/klothoplatform/klotho/pkg/k2/model" -) - -// ResolveInput converts a model.Input to a construct.Input and adds it to the inputs map. -// If the value of the input is a URN, it resolves the URN to a construct. -// If the input's status is not "resolved", it returns an error. -func (ce *ConstructEvaluator) ResolveInput(k string, v model.Input, t InputTemplate) (any, error) { - if v.Status != "" && v.Status != model.InputStatusResolved { - if ce.DryRun == model.DryRunNone { - return nil, fmt.Errorf("input '%s' is not resolved", k) - } - } - switch { - case strings.HasPrefix(t.Type, "Construct("): - cType := strings.TrimSuffix(strings.TrimPrefix(t.Type, "Construct("), ")") - - iURN, ok := v.Value.(model.URN) - if !ok { - urn, err := model.ParseURN(v.DependsOn) - if err != nil { - return nil, fmt.Errorf("input '%s' invalid DependsOn construct URN: %w", k, err) - } - iURN = *urn - } - - if iURN.IsResource() && iURN.Type == "construct" && iURN.Subtype == cType { - ic, ok := ce.Constructs.Get(iURN) - - if !ok { - return nil, fmt.Errorf("input '%s' could not find construct %s", k, iURN) - } - return ic, nil - } else { - return nil, fmt.Errorf("input '%s' invalid construct URN: %+v", k, v) - } - case t.Type == "path": - var err error - pStr, ok := v.Value.(string) - if !ok { - return "", fmt.Errorf("input '%s' invalid path type: expected string, got %T", k, v.Value) - } - path, err := handlePathInput(pStr) - if err != nil { - return nil, fmt.Errorf("input '%s' could not handle path input: %w", k, err) - } - return path, nil - - case t.Type == "KeyValueList": - return handleKeyValueListInput(v.Value, t) - default: - return v.Value, nil - } -} - -// handleKeyValueListInput handles converts a map[string]interface{} to list of key-value pairs. -// Key and value field names are configurable in the input template. The default field names are "Key" and "Value". -func handleKeyValueListInput(value any, t InputTemplate) (any, error) { - if value == nil { - return nil, nil - } - - inputMap, ok := value.(map[string]any) - if !ok { - return nil, fmt.Errorf("expected input to be of type map[string]any, got %T", value) - } - - var keyValueList []any - - keyField := "Key" - valueField := "Value" - if kf, ok := t.Configuration["keyField"]; ok { - if kfs, ok := kf.(string); ok { - keyField = kfs - } - } - - if vf, ok := t.Configuration["valueField"]; ok { - if vfs, ok := vf.(string); ok { - valueField = vfs - } - } - - for key, val := range inputMap { - kvPair := map[string]any{ - keyField: key, - valueField: val, - } - keyValueList = append(keyValueList, kvPair) - } - - return keyValueList, nil -} - -func handlePathInput(value string) (string, error) { - if filepath.IsAbs(value) { - return value, nil - } - - // handle relative paths - wd, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("could not get working directory") - } - return filepath.Join(wd, value), nil -} diff --git a/pkg/k2/constructs/interpolation.go b/pkg/k2/constructs/interpolation.go new file mode 100644 index 000000000..1233e98a9 --- /dev/null +++ b/pkg/k2/constructs/interpolation.go @@ -0,0 +1,413 @@ +package constructs + +import ( + "errors" + "fmt" + template2 "github.com/klothoplatform/klotho/pkg/k2/constructs/template" + "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/klothoplatform/klotho/pkg/reflectutil" + "go.uber.org/zap" + "reflect" + "regexp" + "strconv" + "strings" +) + +// Matches one or more interpolation groups in a string e.g., ${inputs:foo.bar}-baz-${resource:Boz} +var interpolationPattern = regexp.MustCompile(`\$\{([^:]+):([^}]+)}`) + +// Matches exactly one interpolation group e.g., ${inputs:foo.bar} +var isolatedInterpolationPattern = regexp.MustCompile(`^\$\{([^:]+):([^}]+)}$`) + +var spreadPattern = regexp.MustCompile(`\.\.\.}$`) + +// interpolateValue interpolates a value based on the context of the construct +// +// The format of a raw value is ${:} where prefix is the type of value to interpolate and key is the key to interpolate +// +// The key can be a path to a value in the context. +// For example, ${inputs:foo.bar} will interpolate the value of the key bar in the foo input. +// +// The target of a dot-separated path can be a map or a struct. +// The path can also include brackets to access an array or an element whose key contains a dot. +// For example, ${inputs:foo[0].bar} will interpolate the value of the key bar in the first element of the foo input array. +// +// The path can also include a spread operator to expand an array into the current array. +// For example, ${inputs:foo...} will expand the foo input array into the current array. +// +// A rawValue can contain a combination of interpolation expressions, literals, and go templates. +// For example, "${inputs:foo.bar}-baz-${resource:Boz}" is a valid rawValue. +func (ce *ConstructEvaluator) interpolateValue(dv *DynamicValueData, rawValue any) (any, error) { + if ref, ok := rawValue.(template2.ResourceRef); ok { + switch ref.Type { + case template2.ResourceRefTypeInterpolated: + return ce.interpolateValue(dv, ref.ResourceKey) + case template2.ResourceRefTypeTemplate: + ref.ConstructURN = dv.currentOwner.GetURN() + rk, err := ce.interpolateValue(dv, ref.ResourceKey) + if err != nil { + return nil, err + } + ref.ResourceKey = fmt.Sprintf("%s", rk) + return ref, nil + default: + return rawValue, nil + } + } + + v := reflectutil.GetConcreteElement(reflect.ValueOf(rawValue)) + if !v.IsValid() { + return rawValue, nil + } + rawValue = v.Interface() + + switch v.Kind() { + case reflect.String: + resolvedVal, err := ce.interpolateString(dv, v.String()) + if err != nil { + return nil, err + } + return resolvedVal, nil + case reflect.Slice: + length := v.Len() + var interpolated []any + for i := 0; i < length; i++ { + // handle spread operator by injecting the spread value into the array at the current index + originalValue := reflectutil.GetConcreteValue(v.Index(i)) + if originalString, ok := originalValue.(string); ok && spreadPattern.MatchString(originalString) { + unspreadPath := originalString[:len(originalString)-4] + "}" + spreadValue, err := ce.interpolateValue(dv, unspreadPath) + if err != nil { + return nil, err + } + + if spreadValue == nil { + continue + } + if reflect.TypeOf(spreadValue).Kind() != reflect.Slice { + return nil, errors.New("spread value must be a slice") + } + + for i := 0; i < reflect.ValueOf(spreadValue).Len(); i++ { + interpolated = append(interpolated, reflect.ValueOf(spreadValue).Index(i).Interface()) + } + continue + } + value, err := ce.interpolateValue(dv, v.Index(i).Interface()) + if err != nil { + return nil, err + } + interpolated = append(interpolated, value) + } + return interpolated, nil + case reflect.Map: + keys := v.MapKeys() + interpolated := make(map[string]any) + for _, k := range keys { + key, err := ce.interpolateValue(dv, k.Interface()) + if err != nil { + return nil, err + } + value, err := ce.interpolateValue(dv, v.MapIndex(k).Interface()) + if err != nil { + return nil, err + } + interpolated[fmt.Sprint(key)] = value + } + return interpolated, nil + case reflect.Struct: + // Create a new instance of the struct + newStruct := reflect.New(v.Type()).Elem() + + // Interpolate each field + for i := 0; i < v.NumField(); i++ { + fieldName := v.Type().Field(i).Name + fieldValue, err := ce.interpolateValue(dv, v.Field(i).Interface()) + if err != nil { + return nil, err + } + // Set the interpolated value to the field in the new struct + if fieldValue != nil { + newStruct.FieldByName(fieldName).Set(reflect.ValueOf(fieldValue)) + } + } + + // Return the new struct + return newStruct.Interface(), nil + default: + return rawValue, nil + } +} + +func (ce *ConstructEvaluator) interpolateString(dv *DynamicValueData, rawValue string) (any, error) { + // handle go template expressions + if strings.Contains(rawValue, "{{") { + ctx := DynamicValueContext{constructs: ce.Constructs} + err := ctx.ExecuteUnmarshal(rawValue, dv, &rawValue) + if err != nil { + return nil, err + } + } + + ps := dv.propertySource + if ps == nil { + ps = dv.currentOwner.GetPropertySource() + } + + // if the rawValue is an isolated interpolation expression, interpolate it and return the raw value + if isolatedInterpolationPattern.MatchString(rawValue) { + return ce.interpolateExpression(dv.currentOwner, ps, rawValue) + } + + var err error + + // Replace each match in the rawValue (mixed expressions are always interpolated as strings) + interpolated := interpolationPattern.ReplaceAllStringFunc(rawValue, func(match string) string { + var val any + val, err = ce.interpolateExpression(dv.currentOwner, ps, match) + return fmt.Sprint(val) + }) + if err != nil { + return nil, err + } + + return interpolated, nil +} + +func (ce *ConstructEvaluator) interpolateExpression(owner InfraOwner, ps *template2.PropertySource, match string) (any, error) { + if ps == nil { + return nil, errors.New("property source is nil") + } + + // Split the match into prefix and key + parts := interpolationPattern.FindStringSubmatch(match) + prefix := parts[1] + key := parts[2] + + // Choose the correct root property from the source based on the prefix + var p any + ok := false + if prefix == "inputs" || prefix == "resources" || prefix == "edges" || prefix == "meta" || + strings.HasPrefix(prefix, "from.") || + strings.HasPrefix(prefix, "to.") { + p, ok = ps.GetProperty(prefix) + if !ok { + return nil, fmt.Errorf("could not get %s", prefix) + } + } else { + return nil, fmt.Errorf("invalid prefix: %s", prefix) + } + + prefixParts := strings.Split(prefix, ".") + + // associate any ResourceRefs with the URN of the property source they're being interpolated from + // if the prefix is "from" or "to", the URN of the property source is the "urn" field of that level in the property source + var refUrn model.URN + + if strings.HasSuffix(prefix, "resources") { + urnKey := "urn" + if prefixParts[0] == "from" || prefixParts[0] == "to" { + urnKey = fmt.Sprintf("%s.urn", prefixParts[0]) + } + psURN, ok := template2.GetTypedProperty[model.URN](ps, urnKey) + if !ok { + psURN = owner.GetURN() + } + refUrn = psURN + } else { + propTrace, err := reflectutil.TracePath(reflect.ValueOf(p), key) + if err == nil { + refConstruct, ok := reflectutil.LastOfType[*Construct](propTrace) + if ok { + refUrn = refConstruct.URN + } + } + if refUrn.Equals(model.URN{}) { + refUrn = owner.GetURN() + } + } + + // return an IaC reference if the key matches the IaC reference pattern + if iacRefPattern.MatchString(key) { + return template2.ResourceRef{ + ResourceKey: iacRefPattern.FindStringSubmatch(key)[1], + Property: iacRefPattern.FindStringSubmatch(key)[2], + Type: template2.ResourceRefTypeIaC, + ConstructURN: refUrn, + }, nil + } + + // special cases for resources allowing for accessing the name of a resource directly instead of using .Id.Name + if prefix == "resources" || prefixParts[len(prefixParts)-1] == "resources" { + keyParts := reflectutil.SplitPath(key) + resourceKey := strings.Trim(keyParts[0], ".[]") + if len(keyParts) > 1 { + if path := keyParts[1]; path == ".Name" { + return p.(map[string]*Resource)[resourceKey].Id.Name, nil + } + + } + } + + // Retrieve the value from the designated property source + value, err := ce.getValueFromSource(p, key, false) + if err != nil { + zap.S().Debugf("could not get value from source: %s", err) + return nil, nil + } + + keyAndRef := strings.Split(key, "#") + var refProperty string + if len(keyAndRef) == 2 { + refProperty = keyAndRef[1] + } + + // If the value is a Resource, return a ResourceRef + if r, ok := value.(*Resource); ok { + return template2.ResourceRef{ + ResourceKey: r.Id.String(), + Property: refProperty, + Type: template2.ResourceRefTypeIaC, + ConstructURN: refUrn, + }, nil + } + + if r, ok := value.(template2.ResourceRef); ok { + r.ConstructURN = refUrn + return r, nil + } + + // Replace the match with the value + return value, nil +} + +// iacRefPattern is a regular expression pattern that matches an IaC reference +// IaC references are in the format # +var iacRefPattern = regexp.MustCompile(`^([a-zA-Z0-9_-]+)#([a-zA-Z0-9._-]+)$`) + +// indexPattern is a regular expression pattern that matches an array index in the format `[index]` +var indexPattern = regexp.MustCompile(`^\[\d+]$`) + +// getValueFromSource retrieves a value from a property source based on a key +// the flat parameter is used to determine if the key is a flat key or a path (mixed keys aren't supported at the moment) +// e.g (flat = true): key = "foo.bar" -> value = collection["foo."bar"], flat = false: key = "foo.bar" -> value = collection["foo"]["bar"] +func (ce *ConstructEvaluator) getValueFromSource(source any, key string, flat bool) (any, error) { + value := reflect.ValueOf(source) + + keyAndRef := strings.Split(key, "#") + if len(keyAndRef) > 2 { + return nil, fmt.Errorf("invalid engine reference property reference: %s", key) + } + + var refProperty string + if len(keyAndRef) == 2 { + refProperty = keyAndRef[1] + key = keyAndRef[0] + } + + // Split the key into parts if not flat + parts := []string{key} + if !flat { + parts = reflectutil.SplitPath(key) + } + for i, part := range parts { + parts[i] = strings.TrimPrefix(part, ".") + } + + var err error + var lastValidValue reflect.Value + lastValidIndex := -1 + + // Traverse the map/struct/array according to the parts + for i, part := range parts { + // Check if the part is an array index + if indexPattern.MatchString(part) { + // Split the part into the key and the index + part = strings.TrimSuffix(strings.TrimPrefix(part, "["), "]") + var index int + index, err = strconv.Atoi(part) + if err != nil { + err = fmt.Errorf("could not parse index: %w", err) + break + } + + value = reflectutil.GetConcreteElement(value) + kind := value.Kind() + + switch kind { + case reflect.Slice | reflect.Array: + if index >= value.Len() { + err = fmt.Errorf("index out of bounds: %d", index) + break + } + value = value.Index(index) + default: + err = fmt.Errorf("invalid type: %s", kind) + } + } else { + // The part is not an array index + part = strings.TrimSuffix(strings.TrimPrefix(part, "["), "]") + + if value.Kind() == reflect.Map { + v := value.MapIndex(reflect.ValueOf(part)) + if v.IsValid() { + value = v + } else { + err = fmt.Errorf("could not get value for key: %s", key) + break + } + } else if r, ok := value.Interface().(*Resource); ok { + if len(parts) == 1 { + return template2.ResourceRef{ + ResourceKey: part, + Property: refProperty, + Type: template2.ResourceRefTypeTemplate, + }, nil + } else { + // if the parent is a resource, children are implicitly properties of the resource + lastValidValue = reflect.ValueOf(r.Properties) + value, err = reflectutil.GetField(lastValidValue, part) + if err != nil { + err = fmt.Errorf("could not get field: %w", err) + break + } + } + } else if u, ok := value.Interface().(model.URN); ok { + if c, ok := ce.Constructs.Get(u); ok { + lastValidValue = reflect.ValueOf(c) + value, err = reflectutil.GetField(lastValidValue, part) + if err != nil { + err = fmt.Errorf("could not get field: %w", err) + break + } + } else { + err = fmt.Errorf("could not get construct: %s", u) + break + } + } else { + var rVal reflect.Value + rVal, err = reflectutil.GetField(value, part) + if err != nil { + err = fmt.Errorf("could not get field: %w", err) + break + } + value = rVal + } + } + if err != nil { + break + } + if i == len(parts)-1 { + return value.Interface(), nil + } + + lastValidValue = value + lastValidIndex = i + } + + if lastValidIndex > -1 { + return ce.getValueFromSource(lastValidValue.Interface(), strings.Join(parts[lastValidIndex+1:], "."), true) + } + + return value.Interface(), err +} diff --git a/pkg/k2/constructs/template/binding_template.go b/pkg/k2/constructs/template/binding_template.go new file mode 100644 index 000000000..220c9fd8a --- /dev/null +++ b/pkg/k2/constructs/template/binding_template.go @@ -0,0 +1,71 @@ +package template + +import ( + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "gopkg.in/yaml.v3" +) + +type BindingTemplate struct { + From property.ConstructType `yaml:"from"` + To property.ConstructType `yaml:"to"` + Priority int `yaml:"priority"` + Inputs *Properties `yaml:"inputs"` + Outputs map[string]OutputTemplate `yaml:"outputs"` + InputRules []InputRuleTemplate `yaml:"input_rules"` + Resources map[string]ResourceTemplate `yaml:"resources"` + Edges []EdgeTemplate `yaml:"edges"` + resourceOrder []string +} + +func (bt *BindingTemplate) GetInput(path string) property.Property { + return property.GetProperty(bt.Inputs.propertyMap, path) +} + +// ForEachInput walks the input properties of a construct template, +// including nested properties, and calls the given function for each input. +// If the function returns an error, the walk will stop and return that error. +// If the function returns [ErrStopWalk], the walk will stop and return nil. +func (bt *BindingTemplate) ForEachInput(c construct.Properties, f func(property.Property) error) error { + return bt.Inputs.ForEach(c, f) +} + +func (bt *BindingTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] { + return Iterator[string, ResourceTemplate]{ + source: bt.Resources, + order: bt.resourceOrder, + } +} + +func (bt *BindingTemplate) UnmarshalYAML(value *yaml.Node) error { + type bindingTemplate BindingTemplate + var template bindingTemplate + if err := value.Decode(&template); err != nil { + return err + } + resourceOrder, _ := captureYAMLKeyOrder(value, "resources") + template.resourceOrder = resourceOrder + + if template.Inputs == nil { + template.Inputs = NewProperties(nil) + } + + if template.Resources == nil { + template.Resources = make(map[string]ResourceTemplate) + } + + if template.Edges == nil { + template.Edges = make([]EdgeTemplate, 0) + } + + if template.Outputs == nil { + template.Outputs = make(map[string]OutputTemplate) + } + + if template.InputRules == nil { + template.InputRules = make([]InputRuleTemplate, 0) + } + + *bt = BindingTemplate(template) + return nil +} diff --git a/pkg/k2/constructs/construct_template.go b/pkg/k2/constructs/template/construct_template.go similarity index 54% rename from pkg/k2/constructs/construct_template.go rename to pkg/k2/constructs/template/construct_template.go index 941d60c3f..47c594251 100644 --- a/pkg/k2/constructs/construct_template.go +++ b/pkg/k2/constructs/template/construct_template.go @@ -1,34 +1,27 @@ -package constructs +package template import ( "errors" "fmt" - "strings" - "github.com/klothoplatform/klotho/pkg/construct" - "github.com/klothoplatform/klotho/pkg/k2/model" - + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" "gopkg.in/yaml.v3" + "regexp" ) type ( ConstructTemplate struct { - Id ConstructTemplateId `yaml:"id"` + Id property.ConstructType `yaml:"id"` Version string `yaml:"version"` Description string `yaml:"description"` Resources map[string]ResourceTemplate `yaml:"resources"` Edges []EdgeTemplate `yaml:"edges"` - Inputs map[string]InputTemplate `yaml:"inputs"` + Inputs *Properties `yaml:"inputs"` Outputs map[string]OutputTemplate `yaml:"outputs"` InputRules []InputRuleTemplate `yaml:"input_rules"` resourceOrder []string } - ConstructTemplateId struct { - Package string `yaml:"package"` - Name string `yaml:"name"` - } - ResourceTemplate struct { Type string `yaml:"type"` Name string `yaml:"name"` @@ -42,17 +35,6 @@ type ( Data construct.EdgeData `yaml:"data"` } - InputTemplate struct { - Name string `yaml:"name"` - Type string `yaml:"type"` - Description string `yaml:"description"` - Default any `yaml:"default"` - Secret bool `yaml:"secret"` - PulumiKey string `yaml:"pulumi_key"` - Validation ValidationTemplate `yaml:"validation"` - Configuration map[string]any `yaml:"configuration"` - } - OutputTemplate struct { Name string `yaml:"name"` Description string `yaml:"description"` @@ -60,19 +42,23 @@ type ( } InputRuleTemplate struct { - If string `yaml:"if"` - Then ConditionalExpressionTemplate `yaml:"then"` - Else ConditionalExpressionTemplate `yaml:"else"` + If string `yaml:"if"` + Then *ConditionalExpressionTemplate `yaml:"then"` + Else *ConditionalExpressionTemplate `yaml:"else"` + ForEach string `yaml:"for_each"` + Do *ConditionalExpressionTemplate `yaml:"do"` + Prefix string `yaml:"prefix"` } ConditionalExpressionTemplate struct { - Resources map[string]ResourceTemplate `yaml:"resources"` - Edges []EdgeTemplate `yaml:"edges"` - Outputs map[string]OutputTemplate `yaml:"outputs"` + Resources map[string]ResourceTemplate `yaml:"resources"` + Edges []EdgeTemplate `yaml:"edges"` + Outputs map[string]OutputTemplate `yaml:"outputs"` + Rules []InputRuleTemplate `yaml:"rules"` + resourceOrder []string } ValidationTemplate struct { - Required bool `yaml:"required"` MinLength int `yaml:"min_length"` MaxLength int `yaml:"max_length"` MinValue int `yaml:"min_value"` @@ -81,71 +67,12 @@ type ( Enum []string `yaml:"enum"` UniqueValues bool `yaml:"unique_values"` } - - BindingTemplate struct { - From ConstructTemplateId `yaml:"from"` - To ConstructTemplateId `yaml:"to"` - Priority int `yaml:"priority"` - Inputs map[string]InputTemplate `yaml:"inputs"` - Outputs map[string]OutputTemplate `yaml:"outputs"` - InputRules []InputRuleTemplate `yaml:"input_rules"` - Resources map[string]ResourceTemplate `yaml:"resources"` - Edges []EdgeTemplate `yaml:"edges"` - resourceOrder []string - } ) -func (c *ConstructTemplateId) UnmarshalYAML(value *yaml.Node) error { - // Split the value into parts - parts := strings.Split(value.Value, ".") - - // Check if there are at least two parts: package and name - if len(parts) < 2 { - return fmt.Errorf("invalid construct template id: %s", value.Value) - } - - // The name is the last part - c.Name = parts[len(parts)-1] - - // The package is all the parts except the last one, joined by a dot - c.Package = strings.Join(parts[:len(parts)-1], ".") - - return nil -} - -func ParseConstructTemplateId(id string) (ConstructTemplateId, error) { - // Parse a construct template id from a string - parts := strings.Split(id, ".") - if len(parts) < 2 { - return ConstructTemplateId{}, fmt.Errorf("invalid construct template id: %s", id) - } - return ConstructTemplateId{ - Package: strings.Join(parts[:len(parts)-1], "."), - Name: parts[len(parts)-1], - }, nil -} - -func (c *ConstructTemplateId) String() string { - return fmt.Sprintf("%s.%s", c.Package, c.Name) -} - -func (c *ConstructTemplateId) FromURN(urn model.URN) error { - if urn.Type != "construct" { - return fmt.Errorf("invalid urn type: %s", urn.Type) - } - - parts := strings.Split(urn.Subtype, ".") - if len(parts) < 2 { - return fmt.Errorf("invalid construct template id: %s", urn.Subtype) - } - - c.Package = strings.Join(parts[:len(parts)-1], ".") - c.Name = parts[len(parts)-1] - return nil -} +var interpolationPattern = regexp.MustCompile(`\$\{([^:]+):([^}]+)}`) func (e *EdgeTemplate) UnmarshalYAML(value *yaml.Node) error { - // Unmarshal the edge template from a yaml node + // Unmarshal the edge template from a YAML node var edge struct { From string `yaml:"from"` To string `yaml:"to"` @@ -184,14 +111,14 @@ func (e *EdgeTemplate) UnmarshalYAML(value *yaml.Node) error { return nil } -func (c *ConstructTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] { +func (ct *ConstructTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] { return Iterator[string, ResourceTemplate]{ - source: c.Resources, - order: c.resourceOrder, + source: ct.Resources, + order: ct.resourceOrder, } } -func (c *ConstructTemplate) UnmarshalYAML(value *yaml.Node) error { +func (ct *ConstructTemplate) UnmarshalYAML(value *yaml.Node) error { type constructTemplate ConstructTemplate var template constructTemplate if err := value.Decode(&template); err != nil { @@ -199,26 +126,18 @@ func (c *ConstructTemplate) UnmarshalYAML(value *yaml.Node) error { } resourceOrder, _ := captureYAMLKeyOrder(value, "resources") template.resourceOrder = resourceOrder - *c = ConstructTemplate(template) - return nil -} - -func (b *BindingTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] { - return Iterator[string, ResourceTemplate]{ - source: b.Resources, - order: b.resourceOrder, + if template.Inputs == nil { + template.Inputs = NewProperties(nil) + } + if template.Resources == nil { + template.Resources = make(map[string]ResourceTemplate) } -} -func (b *BindingTemplate) UnmarshalYAML(value *yaml.Node) error { - type bindingTemplate BindingTemplate - var template bindingTemplate - if err := value.Decode(&template); err != nil { - return err + if template.Outputs == nil { + template.Outputs = make(map[string]OutputTemplate) } - resourceOrder, _ := captureYAMLKeyOrder(value, "resources") - template.resourceOrder = resourceOrder - *b = BindingTemplate(template) + + *ct = ConstructTemplate(template) return nil } @@ -269,12 +188,12 @@ func (r *Iterator[K, V]) Next() (K, V, bool) { type IterFunc[K comparable, V any] func(K, V) error -var stopIteration = fmt.Errorf("stop iteration") +var StopIteration = fmt.Errorf("stop iteration") func (r *Iterator[K, V]) ForEach(f IterFunc[K, V]) { for key, resource, ok := r.Next(); ok; key, resource, ok = r.Next() { if err := f(key, resource); err != nil { - if errors.Is(err, stopIteration) { + if errors.Is(err, StopIteration) { return } } @@ -292,10 +211,86 @@ const ( BindingDirectionTo = "to" ) -func (c *ConstructTemplate) GetBindingTemplate(direction BindingDirection, other ConstructTemplateId) (BindingTemplate, error) { +func (ct *ConstructTemplate) GetBindingTemplate(direction BindingDirection, other property.ConstructType) (BindingTemplate, error) { if direction == BindingDirectionFrom { - return loadBindingTemplate(c.Id, c.Id, other) + return LoadBindingTemplate(ct.Id, ct.Id, other) } else { - return loadBindingTemplate(c.Id, other, c.Id) + return LoadBindingTemplate(ct.Id, other, ct.Id) + } +} + +func (cet *ConditionalExpressionTemplate) UnmarshalYAML(value *yaml.Node) error { + type conditionalExpressionTemplate ConditionalExpressionTemplate + + var temp conditionalExpressionTemplate + + if err := value.Decode(&temp); err != nil { + return err + } + + cet.Resources = temp.Resources + cet.Edges = temp.Edges + cet.Outputs = temp.Outputs + cet.Rules = temp.Rules + + resourceOrder, _ := captureYAMLKeyOrder(value, "resources") + cet.resourceOrder = resourceOrder + + return nil +} + +func (cet *ConditionalExpressionTemplate) ResourcesIterator() Iterator[string, ResourceTemplate] { + return Iterator[string, ResourceTemplate]{ + source: cet.Resources, + order: cet.resourceOrder, } } + +func (irt *InputRuleTemplate) UnmarshalYAML(value *yaml.Node) error { + type inputRuleTemplate InputRuleTemplate + + var temp inputRuleTemplate + + if err := value.Decode(&temp); err != nil { + return err + } + + if (temp.If == "" && temp.ForEach == "") || (temp.If != "" && temp.ForEach != "") { + return fmt.Errorf("invalid InputRuleTemplate: must have either If-Then-Else or ForEach-Do") + } + + // Check if it's an If-Then-Else structure + if temp.If != "" { + if temp.ForEach != "" || temp.Do != nil { + return fmt.Errorf("invalid InputRuleTemplate: cannot mix If-Then-Else with ForEach-Do") + } + irt.If = temp.If + irt.Then = temp.Then + irt.Else = temp.Else + } else if temp.ForEach != "" { + // Check if it's a ForEach-Do structure + if temp.If != "" || temp.Then != nil || temp.Else != nil { + return fmt.Errorf("invalid InputRuleTemplate: cannot mix ForEach-Do with If-Then-Else") + } + irt.ForEach = temp.ForEach + irt.Do = temp.Do + } else { + return fmt.Errorf("invalid InputRuleTemplate: must have either If-Then-Else or ForEach-Do") + } + + irt.Prefix = temp.Prefix + + return nil +} + +func (ct *ConstructTemplate) GetInput(path string) property.Property { + return property.GetProperty(ct.Inputs.propertyMap, path) +} + +// ForEachInput walks the input properties of a construct template, +// including nested properties, and calls the given function for each input. +// If the function returns an error, the walk will stop and return that error. +// If the function returns [ErrStopWalk], the walk will stop and return nil. +func (ct *ConstructTemplate) ForEachInput(c construct.Properties, f func(property.Property) error) error { + return ct.Inputs.ForEach(c, f) +} diff --git a/pkg/k2/constructs/construct_template_test.go b/pkg/k2/constructs/template/construct_template_test.go similarity index 82% rename from pkg/k2/constructs/construct_template_test.go rename to pkg/k2/constructs/template/construct_template_test.go index abb6bfbdd..26d19ffa8 100644 --- a/pkg/k2/constructs/construct_template_test.go +++ b/pkg/k2/constructs/template/construct_template_test.go @@ -1,10 +1,12 @@ -package constructs +package template import ( "testing" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/klothoplatform/klotho/pkg/construct" - "github.com/klothoplatform/klotho/pkg/k2/model" "github.com/stretchr/testify/assert" "gopkg.in/yaml.v3" ) @@ -13,13 +15,13 @@ func TestUnmarshalConstructTemplateId(t *testing.T) { tests := []struct { name string input string - expected ConstructTemplateId + expected property.ConstructType wantErr bool }{ { name: "Valid ConstructTemplateId", input: "package.name", - expected: ConstructTemplateId{Package: "package", Name: "name"}, + expected: property.ConstructType{Package: "package", Name: "name"}, wantErr: false, }, { @@ -31,7 +33,7 @@ func TestUnmarshalConstructTemplateId(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var ctId ConstructTemplateId + var ctId property.ConstructType err := yaml.Unmarshal([]byte(tt.input), &ctId) if tt.wantErr { assert.Error(t, err) @@ -43,17 +45,17 @@ func TestUnmarshalConstructTemplateId(t *testing.T) { } } -func TestParseConstructTemplateId(t *testing.T) { +func TestParseConstructType(t *testing.T) { tests := []struct { name string input string - expected ConstructTemplateId + expected property.ConstructType wantErr bool }{ { name: "Valid ConstructTemplateId", input: "package.name", - expected: ConstructTemplateId{Package: "package", Name: "name"}, + expected: property.ConstructType{Package: "package", Name: "name"}, wantErr: false, }, { @@ -65,7 +67,7 @@ func TestParseConstructTemplateId(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctId, err := ParseConstructTemplateId(tt.input) + ctId, err := property.ParseConstructType(tt.input) if tt.wantErr { assert.Error(t, err) } else { @@ -76,54 +78,12 @@ func TestParseConstructTemplateId(t *testing.T) { } } -func TestConstructTemplateId_String(t *testing.T) { - ctId := ConstructTemplateId{Package: "package", Name: "name"} +func TestConstructType_String(t *testing.T) { + ctId := property.ConstructType{Package: "package", Name: "name"} expected := "package.name" assert.Equal(t, expected, ctId.String()) } -func TestConstructTemplateId_FromURN(t *testing.T) { - tests := []struct { - name string - input string - expected ConstructTemplateId - wantErr bool - }{ - { - name: "Valid URN", - input: "urn:accountid:project:dev::construct/package.name", - expected: ConstructTemplateId{Package: "package", Name: "name"}, - wantErr: false, - }, - { - name: "Invalid URN type", - input: "urn:accountid:project:dev::other/package.name", - wantErr: true, - }, - { - name: "Invalid URN format", - input: "urn:accountid:project:dev::construct/invalid", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var ctId ConstructTemplateId - urn, err := model.ParseURN(tt.input) - if assert.NoError(t, err) { - err = ctId.FromURN(*urn) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, ctId) - } - } - }) - } -} - func TestEdgeTemplate_UnmarshalYAML(t *testing.T) { tests := []struct { name string @@ -192,7 +152,7 @@ inputs: name: "input1" type: "string" description: "An input" - default: "default" + default_value: "default" outputs: output1: name: "output1" @@ -219,7 +179,7 @@ input_rules: ` expected := ConstructTemplate{ - Id: ConstructTemplateId{Package: "package", Name: "name"}, + Id: property.ConstructType{Package: "package", Name: "name"}, Version: "1.0", Description: "A test template", Resources: map[string]ResourceTemplate{ @@ -233,24 +193,33 @@ input_rules: Data: construct.EdgeData{}, }, }, - Inputs: map[string]InputTemplate{ - "input1": {Name: "input1", Type: "string", Description: "An input", Default: "default"}, - }, + Inputs: NewProperties(property.PropertyMap{ + "input1": &properties.StringProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "input1", + Description: "An input", + Path: "input1", + }, + SharedPropertyFields: properties.SharedPropertyFields{DefaultValue: "default"}, + }, + }), Outputs: map[string]OutputTemplate{ "output1": {Name: "output1", Description: "An output", Value: "value1"}, }, InputRules: []InputRuleTemplate{ { If: "condition", - Then: ConditionalExpressionTemplate{ + Then: &ConditionalExpressionTemplate{ Resources: map[string]ResourceTemplate{ "res3": {Type: "type3", Name: "name3", Namespace: "namespace3", Properties: map[string]any{"prop3": "value3"}}, }, + resourceOrder: []string{"res3"}, }, - Else: ConditionalExpressionTemplate{ + Else: &ConditionalExpressionTemplate{ Resources: map[string]ResourceTemplate{ "res4": {Type: "type4", Name: "name4", Namespace: "namespace4", Properties: map[string]any{"prop4": "value4"}}, }, + resourceOrder: []string{"res4"}, }, }, }, diff --git a/pkg/k2/constructs/template/inputs/properties_template.go b/pkg/k2/constructs/template/inputs/properties_template.go new file mode 100644 index 000000000..9756ca0fa --- /dev/null +++ b/pkg/k2/constructs/template/inputs/properties_template.go @@ -0,0 +1,437 @@ +package inputs + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "gopkg.in/yaml.v3" +) + +type ( + // InputTemplateMap defines the structure of properties defined in YAML as a part of a template. + InputTemplateMap map[string]*InputTemplate + + // InputTemplate defines the structure of a property defined in YAML as a part of a template. + // these fields must be exactly the union of all the fields in the different property types. + InputTemplate struct { + Name string `json:"name" yaml:"name"` + // Type defines the type of the property + Type string `json:"type" yaml:"type"` + // Description defines the description of the property + Description string `json:"description" yaml:"description"` + // DefaultValue defines the default value of the property + DefaultValue any `json:"default_value" yaml:"default_value"` + // Required defines whether the property is required + Required bool `json:"required" yaml:"required"` + + // Properties defines the sub properties of a key_value_list, map, list, or set + Properties InputTemplateMap `json:"properties" yaml:"properties"` + + // MinLength defines the minimum length of a string, list, set, or map (number of entries) + MinLength *int `yaml:"min_length"` + // MaxLength defines the maximum length of a string, list, set, or map (number of entries) + MaxLength *int `yaml:"max_length"` + + // MinValue defines the minimum value of an int or float + MinValue *float64 `yaml:"min_value"` + // MaxValue defines the maximum value of an int or float + MaxValue *float64 `yaml:"max_value"` + + // UniqueItems defines whether the items in a list or set must be unique + UniqueItems *bool `yaml:"unique_items"` + // UniqueKeys defines whether the keys in a map must be unique (default true) + UniqueKeys *bool `yaml:"unique_keys"` + // SanitizeTmpl is a go template to sanitize user input when setting the property + SanitizeTmpl string `yaml:"sanitize"` + // AllowedValues defines an enumeration of allowed values for a string, int, float, or bool + AllowedValues []string `yaml:"allowed_values"` + + // KeyProperty is the property of the keys in a key_value_list or map + KeyProperty *InputTemplate `yaml:"key_property"` + // ValueProperty is the property of the values in a key_value_list or map + ValueProperty *InputTemplate `yaml:"value_property"` + + // ItemProperty is the property of the items in a list or set + ItemProperty *InputTemplate `yaml:"item_property"` + + // Path is the path to the property in the template + // this field is derived and is not part of the yaml + Path string `json:"-" yaml:"-"` + } + + PropertyType string + FieldConverterFunc func(val reflect.Value, p *InputTemplate, kp property.Property) error +) + +var ( + StringPropertyType PropertyType = "string" + IntPropertyType PropertyType = "int" + FloatPropertyType PropertyType = "float" + BoolPropertyType PropertyType = "bool" + MapPropertyType PropertyType = "map" + ListPropertyType PropertyType = "list" + SetPropertyType PropertyType = "set" + AnyPropertyType PropertyType = "any" + PathPropertyType PropertyType = "path" + KeyValueListPropertyType PropertyType = "key_value_list" + ConstructPropertyType PropertyType = "construct" +) + +func (p *InputTemplateMap) UnmarshalYAML(n *yaml.Node) error { + type h InputTemplateMap + var p2 h + err := n.Decode(&p2) + if err != nil { + return err + } + for name, property := range p2 { + property.Name = name + property.Path = name + setChildPaths(property, name) + p2[name] = property + } + *p = InputTemplateMap(p2) + return nil +} + +func (p *InputTemplateMap) Convert() (property.PropertyMap, error) { + var errs error + props := property.PropertyMap{} + for name, prop := range *p { + propertyType, err := prop.Convert() + if err != nil { + errs = fmt.Errorf("%w\n%s", errs, err.Error()) + continue + } + props[name] = propertyType + } + return props, errs +} + +func (p *InputTemplate) Convert() (property.Property, error) { + propertyType, err := InitializeProperty(p.Type) + if err != nil { + return nil, err + } + propertyType.Details().Path = p.Path + + srcVal := reflect.ValueOf(p).Elem() + dstVal := reflect.ValueOf(propertyType).Elem() + for i := 0; i < srcVal.NumField(); i++ { + srcField := srcVal.Field(i) + fieldName := srcVal.Type().Field(i).Name + dstField := dstVal.FieldByName(fieldName) + if !dstField.IsValid() || !dstField.CanSet() { + continue + } + // Skip nil pointers + if (srcField.Kind() == reflect.Ptr || srcField.Kind() == reflect.Interface) && srcField.IsNil() { + continue + // skip empty arrays and slices + } else if (srcField.Kind() == reflect.Array || srcField.Kind() == reflect.Slice) && srcField.Len() == 0 { + continue + } + // Handle sub properties so we can recurse down the tree + switch fieldName { + case "Properties": + propMap := srcField.Interface().(InputTemplateMap) + var errs error + props := property.PropertyMap{} + for name, prop := range propMap { + propertyType, err := prop.Convert() + if err != nil { + errs = fmt.Errorf("%w\n%s", errs, err.Error()) + continue + } + props[name] = propertyType + } + if errs != nil { + return nil, fmt.Errorf("could not convert sub properties: %w", errs) + } + dstField.Set(reflect.ValueOf(props)) + continue + + case "KeyProperty", "ValueProperty": + switch { + case strings.HasPrefix(p.Type, "map"): + keyType, valueType, hasElementTypes := strings.Cut( + strings.TrimSuffix(strings.TrimPrefix(p.Type, "map("), ")"), + ",", + ) + elemProp := srcField.Interface().(*InputTemplate) + // Add the element's type if it is not specified but is on the parent. + // For example, 'map(string,string)' on the parent means the key_property doesn't need 'type: string' + if hasElementTypes { + if fieldName == "KeyProperty" { + if elemProp.Type != "" && elemProp.Type != keyType { + return nil, fmt.Errorf("key property type must be %s (was %s)", keyType, elemProp.Type) + } else if elemProp.Type == "" { + elemProp.Type = keyType + } + } else { + if elemProp.Type != "" && elemProp.Type != valueType { + return nil, fmt.Errorf("value property type must be %s (was %s)", valueType, elemProp.Type) + } else if elemProp.Type == "" { + elemProp.Type = valueType + } + } + } + converted, err := elemProp.Convert() + if err != nil { + return nil, fmt.Errorf("could not convert %s: %w", fieldName, err) + } + srcField = reflect.ValueOf(converted) + case strings.HasPrefix(p.Type, "key_value_list"): + keyType, valueType, hasElementTypes := strings.Cut( + strings.TrimSuffix(strings.TrimPrefix(p.Type, "key_value_list("), ")"), + ",", + ) + keyType = strings.TrimSpace(keyType) + valueType = strings.TrimSpace(valueType) + elemProp := srcField.Interface().(*InputTemplate) + if hasElementTypes { + if fieldName == "KeyProperty" { + if elemProp.Type != "" && elemProp.Type != keyType { + return nil, fmt.Errorf("key property type must be %s (was %s)", keyType, elemProp.Type) + } else if elemProp.Type == "" { + elemProp.Type = keyType + } + } else { + if elemProp.Type != "" && elemProp.Type != valueType { + return nil, fmt.Errorf("value property type must be %s (was %s)", valueType, elemProp.Type) + } else if elemProp.Type == "" { + elemProp.Type = valueType + } + } + } + converted, err := elemProp.Convert() + if err != nil { + return nil, fmt.Errorf("could not convert %s: %w", fieldName, err) + } + srcField = reflect.ValueOf(converted) + default: + return nil, fmt.Errorf("property must be 'map' or 'key_value_list' (was %s) for %s", p.Type, fieldName) + } + case "ItemProperty": + hasItemType := strings.Contains(p.Type, "(") + elemProp := srcField.Interface().(*InputTemplate) + if hasItemType { + itemType := strings.TrimSuffix( + strings.TrimPrefix(strings.TrimPrefix(p.Type, "list("), "set("), + ")", + ) + if elemProp.Type != "" && elemProp.Type != itemType { + return nil, fmt.Errorf("item property type must be %s (was %s)", itemType, elemProp.Type) + } else if elemProp.Type == "" { + elemProp.Type = itemType + } + } + converted, err := elemProp.Convert() + if err != nil { + return nil, fmt.Errorf("could not convert %s: %w", fieldName, err) + } + srcField = reflect.ValueOf(converted) + } + + if srcField.Type().AssignableTo(dstField.Type()) { + dstField.Set(srcField) + continue + } + + if dstField.Kind() == reflect.Ptr && srcField.Kind() == reflect.Ptr { + if srcField.Type().Elem().AssignableTo(dstField.Type().Elem()) { + dstField.Set(srcField) + continue + } else if srcField.Type().Elem().ConvertibleTo(dstField.Type().Elem()) { + val := srcField.Elem().Convert(dstField.Type().Elem()) + // set dest field to a pointer of val + dstField.Set(reflect.New(dstField.Type().Elem())) + dstField.Elem().Set(val) + continue + } + } + + if conversion, found := fieldConversion[fieldName]; found { + err := conversion(srcField, p, propertyType) + if err != nil { + return nil, err + } + continue + } + + return nil, fmt.Errorf( + "could not assign %s#%s (%s) to field in %T (%s)", + p.Path, fieldName, srcField.Type(), propertyType, dstField.Type(), + ) + + } + + return propertyType, nil +} + +func setChildPaths(property *InputTemplate, currPath string) { + for name, child := range property.Properties { + child.Name = name + path := currPath + "." + name + child.Path = path + setChildPaths(child, path) + } +} + +func (p InputTemplateMap) Clone() InputTemplateMap { + newProps := make(InputTemplateMap, len(p)) + for k, v := range p { + newProps[k] = v.Clone() + } + return newProps +} + +func (p *InputTemplate) Clone() *InputTemplate { + cloned := *p + cloned.Properties = make(InputTemplateMap, len(p.Properties)) + for k, v := range p.Properties { + cloned.Properties[k] = v.Clone() + } + return &cloned +} + +// fieldConversion is a map providing functionality on how to convert inputs into our internal types if they are not inherently the same structure +var fieldConversion = map[string]FieldConverterFunc{ + "SanitizeTmpl": func(val reflect.Value, p *InputTemplate, kp property.Property) error { + sanitizeTmpl, ok := val.Interface().(string) + if !ok { + return fmt.Errorf("invalid sanitize template") + } + if sanitizeTmpl == "" { + return nil + } + tmpl, err := property.NewSanitizationTmpl(kp.Details().Name, sanitizeTmpl) + if err != nil { + return err + } + dstField := reflect.ValueOf(kp).Elem().FieldByName("SanitizeTmpl") + dstField.Set(reflect.ValueOf(tmpl)) + return nil + }, +} + +func InitializeProperty(ptype string) (property.Property, error) { + if ptype == "" { + return nil, fmt.Errorf("property does not have a type") + } + baseType, typeArgs, err := GetTypeInfo(ptype) + if err != nil { + return nil, err + } + switch baseType { + case MapPropertyType: + if len(typeArgs) == 0 { + return &properties.MapProperty{}, nil + } + if len(typeArgs) != 2 { + return nil, fmt.Errorf("invalid number of arguments for map property type: %s", ptype) + } + keyVal, err := InitializeProperty(typeArgs[0]) + if err != nil { + return nil, err + } + valProp, err := InitializeProperty(typeArgs[1]) + if err != nil { + return nil, err + } + return &properties.MapProperty{KeyProperty: keyVal, ValueProperty: valProp}, nil + case ListPropertyType: + if len(typeArgs) == 0 { + return &properties.ListProperty{}, nil + } + if len(typeArgs) != 1 { + return nil, fmt.Errorf("invalid number of arguments for list property type: %s", ptype) + } + itemProp, err := InitializeProperty(typeArgs[0]) + if err != nil { + return nil, err + } + return &properties.ListProperty{ItemProperty: itemProp}, nil + case SetPropertyType: + if len(typeArgs) == 0 { + return &properties.SetProperty{}, nil + } + if len(typeArgs) != 1 { + return nil, fmt.Errorf("invalid number of arguments for set property type: %s", ptype) + } + itemProp, err := InitializeProperty(typeArgs[0]) + if err != nil { + return nil, err + } + return &properties.SetProperty{ItemProperty: itemProp}, nil + case KeyValueListPropertyType: + if len(typeArgs) == 0 { + return &properties.KeyValueListProperty{}, nil + } + if len(typeArgs) != 2 { + return nil, fmt.Errorf("invalid number of arguments for %s property type: %s", KeyValueListPropertyType, ptype) + } + keyPropType := typeArgs[0] + valPropType := typeArgs[1] + keyProp, err := InitializeProperty(keyPropType) + keyProp.Details().Name = "Key" + if err != nil { + return nil, err + } + valProp, err := InitializeProperty(valPropType) + valProp.Details().Name = "Value" + if err != nil { + return nil, err + } + return &properties.KeyValueListProperty{KeyProperty: keyProp, ValueProperty: valProp}, nil + case ConstructPropertyType: + var allowedTypes []property.ConstructType + if len(typeArgs) > 0 { + for _, t := range typeArgs { + var id property.ConstructType + err := id.FromString(t) + if err != nil { + return nil, fmt.Errorf("invalid construct type %s: %w", t, err) + } + allowedTypes = append(allowedTypes, id) + } + } + return &properties.ConstructProperty{AllowedTypes: allowedTypes}, nil + case AnyPropertyType: + return &properties.AnyProperty{}, nil + case StringPropertyType: + return &properties.StringProperty{}, nil + case IntPropertyType: + return &properties.IntProperty{}, nil + case FloatPropertyType: + return &properties.FloatProperty{}, nil + case BoolPropertyType: + return &properties.BoolProperty{}, nil + case PathPropertyType: + return &properties.PathProperty{}, nil + default: + return nil, fmt.Errorf("unknown property type '%s'", baseType) + } + +} + +var funcRegex = regexp.MustCompile(`^(\w+)(?:\(([^)]*)\))?$`) +var argRegex = regexp.MustCompile(`[^,]+`) + +func GetTypeInfo(t string) (propType PropertyType, args []string, err error) { + matches := funcRegex.FindStringSubmatch(t) + if matches == nil { + return "", nil, fmt.Errorf("invalid property type %s", t) + } + propType = PropertyType(matches[1]) + args = argRegex.FindAllString(matches[2], -1) + for i, arg := range args { + args[i] = strings.TrimSpace(arg) + } + + return propType, args, nil +} diff --git a/pkg/k2/constructs/template/inputs/properties_template_test.go b/pkg/k2/constructs/template/inputs/properties_template_test.go new file mode 100644 index 000000000..7cb57b666 --- /dev/null +++ b/pkg/k2/constructs/template/inputs/properties_template_test.go @@ -0,0 +1,251 @@ +package inputs + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/properties" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ConvertProperty(t *testing.T) { + tests := []struct { + name string + property InputTemplate + expected property.Property + }{ + { + name: "Convert string property type", + property: InputTemplate{ + Type: "string", + Name: "test", + Path: "test", + Required: true, + AllowedValues: []string{"test1", "test2"}, + DefaultValue: "test", + }, + expected: &properties.StringProperty{ + SharedPropertyFields: properties.SharedPropertyFields{ + DefaultValue: "test", + }, + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + AllowedValues: []string{"test1", "test2"}, + }, + }, + { + name: "Convert int property type", + property: InputTemplate{ + Type: "int", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.IntProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + }, + }, + { + name: "Convert float property type", + property: InputTemplate{ + Type: "float", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.FloatProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + }, + }, + { + name: "Convert bool property type", + property: InputTemplate{ + Type: "bool", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.BoolProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + }, + }, + { + name: "Convert map property type", + property: InputTemplate{ + Type: "map(string,string)", + Name: "test", + Path: "test", + Required: true, + KeyProperty: &InputTemplate{ + Type: "string", + }, + ValueProperty: &InputTemplate{ + Type: "string", + }, + }, + expected: &properties.MapProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + Properties: map[string]property.Property{}, + KeyProperty: &properties.StringProperty{}, + ValueProperty: &properties.StringProperty{}, + }, + }, + { + name: "Convert list property type", + property: InputTemplate{ + Type: "list(string)", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.ListProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + ItemProperty: &properties.StringProperty{}, + Properties: map[string]property.Property{}, + }, + }, + { + name: "Convert set property type", + property: InputTemplate{ + Type: "set(string)", + Name: "test", + Path: "test", + Required: true, + ItemProperty: &InputTemplate{ + Type: "string", + }, + }, + expected: &properties.SetProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + ItemProperty: &properties.StringProperty{}, + Properties: map[string]property.Property{}, + }, + }, + { + name: "Convert key_value_list property type", + property: InputTemplate{ + Type: "key_value_list(string,string)", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + KeyProperty: &properties.StringProperty{PropertyDetails: property.PropertyDetails{Name: "Key"}}, + ValueProperty: &properties.StringProperty{PropertyDetails: property.PropertyDetails{Name: "Value"}}, + }, + }, + { + name: "Convert key_value_list property type with custom key and value properties", + property: InputTemplate{ + Type: "key_value_list(string,string)", + Name: "test", + Path: "test", + KeyProperty: &InputTemplate{ + Type: "string", + Name: "CustomKeyKey", + }, + ValueProperty: &InputTemplate{ + Type: "string", + Name: "CustomValueKey", + }, + }, + expected: &properties.KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Path: "test", + }, + KeyProperty: &properties.StringProperty{PropertyDetails: property.PropertyDetails{Name: "CustomKeyKey"}}, + ValueProperty: &properties.StringProperty{PropertyDetails: property.PropertyDetails{Name: "CustomValueKey"}}, + }, + }, + { + name: "Convert construct property type", + property: InputTemplate{ + Type: "construct", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.ConstructProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + }, + }, + { + name: "Convert any property type", + property: InputTemplate{ + Type: "any", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.AnyProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + }, + }, + { + name: "Convert path property type", + property: InputTemplate{ + Type: "path", + Name: "test", + Path: "test", + Required: true, + }, + expected: &properties.PathProperty{ + PropertyDetails: property.PropertyDetails{ + Name: "test", + Required: true, + Path: "test", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actual, err := test.property.Convert() + require.NoError(t, err) + assert.EqualValuesf(t, actual, test.expected, "expected %v, got %v", test.expected, actual) + }) + } +} diff --git a/pkg/k2/constructs/template/properties.go b/pkg/k2/constructs/template/properties.go new file mode 100644 index 000000000..31514f516 --- /dev/null +++ b/pkg/k2/constructs/template/properties.go @@ -0,0 +1,68 @@ +package template + +import ( + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/inputs" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "gopkg.in/yaml.v3" +) + +type Properties struct { + propertyMap property.PropertyMap +} + +func NewProperties(properties property.PropertyMap) *Properties { + if properties == nil { + properties = make(property.PropertyMap) + } + + return &Properties{ + propertyMap: properties, + } +} + +func (p *Properties) Clone() property.Properties { + newProps := Properties{ + propertyMap: p.propertyMap.Clone(), + } + return &newProps +} + +func (p *Properties) ForEach(c construct.Properties, f func(p property.Property) error) error { + return p.propertyMap.ForEach(c, f) +} + +func (p *Properties) Get(key string) (property.Property, bool) { + return p.propertyMap.Get(key) +} + +func (p *Properties) Set(key string, value property.Property) { + p.propertyMap.Set(key, value) +} + +func (p *Properties) Remove(key string) { + p.propertyMap.Remove(key) +} + +func (p *Properties) AsMap() map[string]property.Property { + return p.propertyMap +} + +func (p *Properties) UnmarshalYAML(node *yaml.Node) error { + if p.propertyMap == nil { + p.propertyMap = make(property.PropertyMap) + } + + ip := make(inputs.InputTemplateMap) + if err := node.Decode(&ip); err != nil { + return err + } + converted, err := ip.Convert() + if err != nil { + return err + } + for k, v := range converted { + p.propertyMap[k] = v + } + return nil +} diff --git a/pkg/k2/constructs/template/properties/any_property.go b/pkg/k2/constructs/template/properties/any_property.go new file mode 100644 index 000000000..96af162bd --- /dev/null +++ b/pkg/k2/constructs/template/properties/any_property.go @@ -0,0 +1,108 @@ +package properties + +import ( + "fmt" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" +) + +type ( + AnyProperty struct { + SharedPropertyFields + property.PropertyDetails + } +) + +func (a *AnyProperty) SetProperty(properties construct.Properties, value any) error { + return properties.SetProperty(a.Path, value) +} + +func (a *AnyProperty) AppendProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(a.Path) + if err != nil { + return err + } + if propVal == nil { + return a.SetProperty(properties, value) + } + return properties.AppendProperty(a.Path, value) +} + +func (a *AnyProperty) RemoveProperty(properties construct.Properties, value any) error { + return properties.RemoveProperty(a.Path, value) +} + +func (a *AnyProperty) Details() *property.PropertyDetails { + return &a.PropertyDetails +} + +func (a *AnyProperty) Clone() property.Property { + clone := *a + return &clone +} + +func (a *AnyProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if a.DefaultValue == nil { + return nil, nil + } + return a.Parse(a.DefaultValue, ctx, data) +} + +func (a *AnyProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + if val, ok := value.(string); ok { + // check if its any other template string + var result any + err := ctx.ExecuteUnmarshal(val, data, &result) + if err == nil { + return result, nil + } + } + + if mapVal, ok := value.(map[string]any); ok { + m := MapProperty{KeyProperty: &StringProperty{}, ValueProperty: &AnyProperty{}} + return m.Parse(mapVal, ctx, data) + } + + if listVal, ok := value.([]any); ok { + l := ListProperty{ItemProperty: &AnyProperty{}} + return l.Parse(listVal, ctx, data) + } + + return value, nil +} + +func (a *AnyProperty) ZeroValue() any { + return nil +} + +func (a *AnyProperty) Contains(value any, contains any) bool { + if val, ok := value.(string); ok { + s := StringProperty{} + return s.Contains(val, contains) + } + if mapVal, ok := value.(map[string]any); ok { + m := MapProperty{KeyProperty: &StringProperty{}, ValueProperty: &AnyProperty{}} + return m.Contains(mapVal, contains) + } + if listVal, ok := value.([]any); ok { + l := ListProperty{ItemProperty: &AnyProperty{}} + return l.Contains(listVal, contains) + } + return false +} + +func (a *AnyProperty) Type() string { + return "any" +} + +func (a *AnyProperty) Validate(properties construct.Properties, value any) error { + if a.Required && value == nil { + return fmt.Errorf(property.ErrRequiredProperty, a.Path) + } + return nil +} + +func (a *AnyProperty) SubProperties() property.PropertyMap { + return nil +} diff --git a/pkg/k2/constructs/template/properties/any_property_test.go b/pkg/k2/constructs/template/properties/any_property_test.go new file mode 100644 index 000000000..d7263de2a --- /dev/null +++ b/pkg/k2/constructs/template/properties/any_property_test.go @@ -0,0 +1,230 @@ +package properties + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" +) + +// Testing the SetProperty method for different cases +func Test_AnyProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *AnyProperty + input any + wantError bool + }{ + { + name: "valid string value", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: "valid_string", + wantError: false, + }, + { + name: "valid map value", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: map[string]any{"key1": "value1", "key2": "value2"}, + wantError: false, + }, + { + name: "valid list value", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: []any{"item1", "item2"}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the ZeroValue method +func Test_AnyProperty_ZeroValue(t *testing.T) { + assert := assert.New(t) + property := &AnyProperty{} + assert.Nil(property.ZeroValue()) +} + +// Testing the Details method +func Test_AnyProperty_Details(t *testing.T) { + assert := assert.New(t) + property := &AnyProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +// Testing the Clone method +func Test_AnyProperty_Clone(t *testing.T) { + property := &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + } + clone := property.Clone() + assert.Equal(t, property, clone) +} + +// Testing the AppendProperty method with different cases +func Test_AnyProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *AnyProperty + properties construct.Properties + input any + expected any + wantError bool + }{ + { + name: "append to empty property", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: map[string]any{"key1": "value1"}, + wantError: true, + }, + { + name: "append to existing map property", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{ + "test": map[string]any{"key1": "value1"}, + }, + input: map[string]any{"key2": "value2"}, + expected: map[string]any{"key1": "value1", "key2": "value2"}, + wantError: false, + }, + { + name: "append invalid type to map property", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{ + "test": map[string]any{"key1": "value1"}, + }, + input: "invalid_value", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.AppendProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, tt.properties[tt.property.Path]) + } + }) + } +} + +// Testing the RemoveProperty method +func Test_AnyProperty_RemoveProperty(t *testing.T) { + tests := []struct { + name string + property *AnyProperty + properties construct.Properties + input any + expected any + wantError bool + }{ + { + name: "remove existing map entry", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{ + "test": map[string]any{"key1": "value1", "key2": "value2"}, + }, + input: "key1", + expected: map[string]any{"key2": "value2"}, + wantError: true, // DS - not sure if this is correct or an existing bug + }, + { + name: "remove existing list entry", + property: &AnyProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{ + "test": []any{"item1", "item2"}, + }, + input: "item1", + expected: []any{"item2"}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.RemoveProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, tt.properties[tt.property.Path]) + } + }) + } +} + +// Testing the Parse method +func Test_AnyProperty_Parse(t *testing.T) { + tests := []struct { + name string + property *AnyProperty + input any + expected any + wantError bool + }{ + { + name: "parse string template", + property: &AnyProperty{}, + input: "{{ toUpper \"VALUE\" }}", + expected: "VALUE", + wantError: false, + }, + { + name: "parse map value", + property: &AnyProperty{}, + input: map[string]any{"key1": "value1", "key2": "value2"}, + expected: map[string]any{"key1": "value1", "key2": "value2"}, + wantError: false, + }, + { + name: "parse list value", + property: &AnyProperty{}, + input: []any{"item1", "item2"}, + expected: []any{"item1", "item2"}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.property.Parse(tt.input, DefaultExecutionContext{}, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} diff --git a/pkg/k2/constructs/template/properties/bool_property.go b/pkg/k2/constructs/template/properties/bool_property.go new file mode 100644 index 000000000..f62078c57 --- /dev/null +++ b/pkg/k2/constructs/template/properties/bool_property.go @@ -0,0 +1,101 @@ +package properties + +import ( + "errors" + "fmt" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" +) + +type ( + BoolProperty struct { + SharedPropertyFields + property.PropertyDetails + } +) + +func (b *BoolProperty) SetProperty(properties construct.Properties, value any) error { + if val, ok := value.(bool); ok { + return properties.SetProperty(b.Path, val) + } else if val, ok := value.(construct.PropertyRef); ok { + return properties.SetProperty(b.Path, val) + } + return fmt.Errorf("invalid bool value %v", value) +} + +func (b *BoolProperty) AppendProperty(properties construct.Properties, value any) error { + return b.SetProperty(properties, value) +} + +func (b *BoolProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(b.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + return properties.RemoveProperty(b.Path, value) +} + +func (b *BoolProperty) Clone() property.Property { + clone := *b + return &clone +} + +func (b *BoolProperty) Details() *property.PropertyDetails { + return &b.PropertyDetails +} + +func (b *BoolProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if b.DefaultValue == nil { + return nil, nil + } + return b.Parse(b.DefaultValue, ctx, data) +} + +func (b *BoolProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + if val, ok := value.(string); ok { + var result bool + err := ctx.ExecuteUnmarshal(val, data, &result) + return result, err + } + if val, ok := value.(bool); ok { + return val, nil + } + + return nil, fmt.Errorf("invalid bool value %v", value) +} + +func (b *BoolProperty) ZeroValue() any { + return false +} + +func (b *BoolProperty) Contains(value any, contains any) bool { + return false +} + +func (b *BoolProperty) Type() string { + return "bool" +} + +func (b *BoolProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if b.Required { + return fmt.Errorf(property.ErrRequiredProperty, b.Path) + } + return nil + } + if _, ok := value.(bool); !ok { + return fmt.Errorf("invalid bool value %v", value) + } + return nil +} + +func (b *BoolProperty) SubProperties() property.PropertyMap { + return nil +} diff --git a/pkg/k2/constructs/template/properties/bool_property_test.go b/pkg/k2/constructs/template/properties/bool_property_test.go new file mode 100644 index 000000000..b7ca7aeac --- /dev/null +++ b/pkg/k2/constructs/template/properties/bool_property_test.go @@ -0,0 +1,122 @@ +package properties + +import ( + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_BoolProppertySetProperty(t *testing.T) { + tests := []struct { + name string + property *BoolProperty + input any + wantError bool + }{ + { + name: "valid bool value", + property: &BoolProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: true, + wantError: false, + }, + { + name: "invalid value type", + property: &BoolProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: 123, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_BoolPropertyZeroValue(t *testing.T) { + assert := assert.New(t) + property := &BoolProperty{} + assert.Equal(false, property.ZeroValue()) +} + +func Test_BoolPropertyDetails(t *testing.T) { + assert := assert.New(t) + property := &BoolProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +func Test_BoolPropertyContains(t *testing.T) { + assert := assert.New(t) + property := &BoolProperty{} + assert.False(property.Contains(nil, nil)) +} + +func Test_BoolPropertyType(t *testing.T) { + assert := assert.New(t) + property := &BoolProperty{} + assert.Equal("bool", property.Type()) +} + +func Test_BoolPropertyValidate(t *testing.T) { + tests := []struct { + name string + property *BoolProperty + properties construct.Properties + value any + wantErr bool + }{ + { + name: "bool property", + property: &BoolProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + value: true, + wantErr: false, + }, + { + name: "invalid value", + property: &BoolProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + value: 1, + wantErr: true, + }, + { + name: "nil value for required property", + property: &BoolProperty{ + PropertyDetails: property.PropertyDetails{Path: "test", Required: true}, + }, + value: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.Validate(tt.properties, tt.value) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_BoolPropertySubProperties(t *testing.T) { + assert := assert.New(t) + property := &BoolProperty{} + assert.Nil(property.SubProperties()) +} diff --git a/pkg/k2/constructs/template/properties/construct_property.go b/pkg/k2/constructs/template/properties/construct_property.go new file mode 100644 index 000000000..15f90467f --- /dev/null +++ b/pkg/k2/constructs/template/properties/construct_property.go @@ -0,0 +1,176 @@ +package properties + +import ( + "errors" + "fmt" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/klothoplatform/klotho/pkg/k2/model" +) + +type ConstructTemplateIdList []property.ConstructType + +func (l ConstructTemplateIdList) MatchesAny(urn model.URN) bool { + var id property.ConstructType + err := id.FromURN(urn) + if err != nil { + return false + } + for _, t := range l { + if t == id { + return true + } + } + return false + +} + +type ( + ConstructProperty struct { + AllowedTypes ConstructTemplateIdList + SharedPropertyFields + property.PropertyDetails + } +) + +func (r *ConstructProperty) SetProperty(properties construct.Properties, value any) error { + if val, ok := value.(model.URN); ok { + return properties.SetProperty(r.Path, val) + } + return fmt.Errorf("invalid construct URN %v", value) +} + +func (r *ConstructProperty) AppendProperty(properties construct.Properties, value any) error { + return r.SetProperty(properties, value) +} + +func (r *ConstructProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(r.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + propId, ok := propVal.(model.URN) + if !ok { + return fmt.Errorf("error attempting to remove construct property: invalid property value %v", propVal) + } + valId, ok := value.(model.URN) + if !ok { + return fmt.Errorf("error attempting to remove construct property: invalid construct value %v", value) + } + if !propId.Equals(valId) { + return fmt.Errorf("error attempting to remove construct property: construct value %v does not match property value %v", value, propVal) + } + return properties.RemoveProperty(r.Path, value) +} + +func (r *ConstructProperty) Details() *property.PropertyDetails { + return &r.PropertyDetails +} +func (r *ConstructProperty) Clone() property.Property { + clone := *r + return &clone +} + +func (r *ConstructProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if r.DefaultValue == nil { + return nil, nil + } + return r.Parse(r.DefaultValue, ctx, data) +} + +func (r *ConstructProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + if val, ok := value.(string); ok { + urn, err := ExecuteUnmarshalAsURN(ctx, val, data) + if err != nil { + return nil, fmt.Errorf("invalid construct URN %v", val) + } + if !urn.IsResource() || urn.Type != "construct" { + return nil, fmt.Errorf("invalid construct URN %v", urn) + } + if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(urn) { + return nil, fmt.Errorf("construct value %v does not match allowed types %s", value, r.AllowedTypes) + } + return urn, err + } + + if val, ok := value.(map[string]interface{}); ok { + id := model.URN{ + AccountID: val["account"].(string), + Project: val["project"].(string), + Environment: val["environment"].(string), + Application: val["application"].(string), + Type: val["type"].(string), + Subtype: val["subtype"].(string), + ParentResourceID: val["parentResourceId"].(string), + ResourceID: val["resourceId"].(string), + } + + if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id) { + return nil, fmt.Errorf("construct value %v does not match type %s", value, r.AllowedTypes) + } + return id, nil + } + if val, ok := value.(model.URN); ok { + if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(val) { + return nil, fmt.Errorf("construct value %v does not match type %s", value, r.AllowedTypes) + } + return val, nil + } + + return nil, fmt.Errorf("invalid construct value %v", value) +} + +func (r *ConstructProperty) ZeroValue() any { + return model.URN{} +} + +func (r *ConstructProperty) Contains(value any, contains any) bool { + if val, ok := value.(model.URN); ok { + if cont, ok := contains.(model.URN); ok { + return val.Equals(cont) + } + } + return false +} + +func (r *ConstructProperty) Type() string { + if len(r.AllowedTypes) > 0 { + typeString := "" + for i, t := range r.AllowedTypes { + typeString += t.String() + if i < len(r.AllowedTypes)-1 { + typeString += ", " + } + } + return fmt.Sprintf("construct(%s)", typeString) + } + return "construct" +} + +func (r *ConstructProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if r.Required { + return fmt.Errorf(property.ErrRequiredProperty, r.Path) + } + return nil + } + id, ok := value.(model.URN) + if !ok { + return fmt.Errorf("invalid construct URN %v", value) + } + if len(r.AllowedTypes) > 0 && !r.AllowedTypes.MatchesAny(id) { + return fmt.Errorf("value %v does not match allowed types %s", value, r.AllowedTypes) + } + return nil +} + +func (r *ConstructProperty) SubProperties() property.PropertyMap { + return nil +} diff --git a/pkg/k2/constructs/template/properties/float_property.go b/pkg/k2/constructs/template/properties/float_property.go new file mode 100644 index 000000000..8cd74973c --- /dev/null +++ b/pkg/k2/constructs/template/properties/float_property.go @@ -0,0 +1,122 @@ +package properties + +import ( + "errors" + "fmt" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" +) + +type ( + FloatProperty struct { + MinValue *float64 + MaxValue *float64 + SharedPropertyFields + property.PropertyDetails + } +) + +func (f *FloatProperty) SetProperty(properties construct.Properties, value any) error { + switch val := value.(type) { + case float64: + return properties.SetProperty(f.Path, val) + case construct.PropertyRef: + return properties.SetProperty(f.Path, val) + case float32: + return properties.SetProperty(f.Path, float64(val)) + case int: + return properties.SetProperty(f.Path, float64(val)) + default: + return fmt.Errorf("invalid float value %v", value) + } +} + +func (f *FloatProperty) AppendProperty(properties construct.Properties, value any) error { + return f.SetProperty(properties, value) +} + +func (f *FloatProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(f.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + return properties.RemoveProperty(f.Path, value) + +} + +func (f *FloatProperty) Details() *property.PropertyDetails { + return &f.PropertyDetails +} + +func (f *FloatProperty) Clone() property.Property { + clone := *f + return &clone +} + +func (f *FloatProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if f.DefaultValue == nil { + return nil, nil + } + return f.Parse(f.DefaultValue, ctx, data) +} + +func (f *FloatProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + if val, ok := value.(string); ok { + var result float32 + err := ctx.ExecuteUnmarshal(val, data, &result) + return result, err + } + if val, ok := value.(float32); ok { + return val, nil + } + if val, ok := value.(float64); ok { + return val, nil + } + if val, ok := value.(int); ok { + return float64(val), nil + } + return nil, fmt.Errorf("invalid float value %v", value) +} + +func (f *FloatProperty) ZeroValue() any { + return 0.0 +} + +func (f *FloatProperty) Contains(value any, contains any) bool { + return false +} + +func (f *FloatProperty) Type() string { + return "float" +} + +func (f *FloatProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if f.Required { + return fmt.Errorf(property.ErrRequiredProperty, f.Path) + } + return nil + } + floatVal, ok := value.(float64) + if !ok { + return fmt.Errorf("invalid float value %v", value) + } + if f.MinValue != nil && floatVal < *f.MinValue { + return fmt.Errorf("float value %f is less than lower bound %f", value, *f.MinValue) + } + if f.MaxValue != nil && floatVal > *f.MaxValue { + return fmt.Errorf("float value %f is greater than upper bound %f", value, *f.MaxValue) + } + return nil +} + +func (f *FloatProperty) SubProperties() property.PropertyMap { + return nil +} diff --git a/pkg/k2/constructs/template/properties/float_property_test.go b/pkg/k2/constructs/template/properties/float_property_test.go new file mode 100644 index 000000000..0e6b14a6f --- /dev/null +++ b/pkg/k2/constructs/template/properties/float_property_test.go @@ -0,0 +1,251 @@ +package properties + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" +) + +// Testing the SetProperty method for different cases +func Test_FloatProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *FloatProperty + input any + wantError bool + }{ + { + name: "valid float64 value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: float64(3.14), + wantError: false, + }, + { + name: "valid float32 value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: float32(3.14), + wantError: false, + }, + { + name: "valid int value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: 42, + wantError: false, + }, + { + name: "invalid string value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: "invalid_float", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the ZeroValue method +func Test_FloatProperty_ZeroValue(t *testing.T) { + assert := assert.New(t) + property := &FloatProperty{} + assert.Equal(0.0, property.ZeroValue()) +} + +// Testing the Details method +func Test_FloatProperty_Details(t *testing.T) { + assert := assert.New(t) + property := &FloatProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +// Testing the Clone method +func Test_FloatProperty_Clone(t *testing.T) { + property := &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + MinValue: new(float64), + MaxValue: new(float64), + } + clone := property.Clone() + assert.Equal(t, property, clone) +} + +// Testing the AppendProperty method with different cases +func Test_FloatProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *FloatProperty + properties construct.Properties + input any + wantError bool + }{ + { + name: "append float64 value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: float64(3.14), + wantError: false, + }, + { + name: "append int value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: 42, + wantError: false, + }, + { + name: "append invalid value", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: "invalid_float", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.AppendProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the RemoveProperty method +func Test_FloatProperty_RemoveProperty(t *testing.T) { + tests := []struct { + name string + property *FloatProperty + properties construct.Properties + input any + wantError bool + }{ + { + name: "remove existing float property", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{"test": float64(3.14)}, + wantError: false, + }, + { + name: "remove non-existent property", + property: &FloatProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.RemoveProperty(tt.properties, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the Parse method +func Test_FloatProperty_Parse(t *testing.T) { + tests := []struct { + name string + property *FloatProperty + input any + expected any + wantError bool + }{ + { + name: "parse string to float", + property: &FloatProperty{}, + input: "3.14", + expected: float32(3.14), + wantError: false, + }, + { + name: "parse int to float", + property: &FloatProperty{}, + input: 42, + expected: float64(42), + wantError: false, + }, + { + name: "parse float64", + property: &FloatProperty{}, + input: float64(3.14), + expected: float64(3.14), + wantError: false, + }, + { + name: "parse invalid string", + property: &FloatProperty{}, + input: "invalid_float", + expected: nil, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.property.Parse(tt.input, DefaultExecutionContext{}, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// Testing the Contains method +func Test_FloatProperty_Contains(t *testing.T) { + assert := assert.New(t) + property := &FloatProperty{} + assert.False(property.Contains(1.0, 1.0)) +} + +// Testing the Type method +func Test_FloatProperty_Type(t *testing.T) { + assert := assert.New(t) + property := &FloatProperty{} + assert.Equal("float", property.Type()) +} + +// Testing the SubProperties method +func Test_FloatProperty_SubProperties(t *testing.T) { + assert := assert.New(t) + property := &FloatProperty{} + assert.Nil(property.SubProperties()) +} diff --git a/pkg/k2/constructs/template/properties/int_property.go b/pkg/k2/constructs/template/properties/int_property.go new file mode 100644 index 000000000..64c901b8d --- /dev/null +++ b/pkg/k2/constructs/template/properties/int_property.go @@ -0,0 +1,126 @@ +package properties + +import ( + "errors" + "fmt" + "math" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" +) + +type ( + IntProperty struct { + MinValue *int + MaxValue *int + SharedPropertyFields + property.PropertyDetails + } +) + +func (i *IntProperty) SetProperty(properties construct.Properties, value any) error { + if val, ok := value.(int); ok { + return properties.SetProperty(i.Path, val) + } else if val, ok := value.(construct.PropertyRef); ok { + return properties.SetProperty(i.Path, val) + } + return fmt.Errorf("invalid int value %v", value) +} + +func (i *IntProperty) AppendProperty(properties construct.Properties, value any) error { + return i.SetProperty(properties, value) +} + +func (i *IntProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(i.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + return properties.RemoveProperty(i.Path, value) +} + +func (i *IntProperty) Details() *property.PropertyDetails { + return &i.PropertyDetails +} + +func (i *IntProperty) Clone() property.Property { + clone := *i + return &clone +} + +func (i *IntProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if i.DefaultValue == nil { + return nil, nil + } + return i.Parse(i.DefaultValue, ctx, data) +} + +func (i *IntProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + + if val, ok := value.(string); ok { + var result int + err := ctx.ExecuteUnmarshal(val, data, &result) + return result, err + } + if val, ok := value.(int); ok { + return val, nil + } + EPSILON := 0.0000001 + if val, ok := value.(float32); ok { + ival := int(val) + if math.Abs(float64(val)-float64(ival)) > EPSILON { + return 0, fmt.Errorf("cannot convert non-integral float to int: %f", val) + } + return int(val), nil + + } else if val, ok := value.(float64); ok { + ival := int(val) + if math.Abs(val-float64(ival)) > EPSILON { + return 0, fmt.Errorf("cannot convert non-integral float to int: %f", val) + } + return int(val), nil + } + return nil, fmt.Errorf("invalid int value %v", value) +} + +func (i *IntProperty) ZeroValue() any { + return 0 +} + +func (i *IntProperty) Contains(value any, contains any) bool { + return false +} + +func (i *IntProperty) Type() string { + return "int" +} + +func (i *IntProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if i.Required { + return fmt.Errorf(property.ErrRequiredProperty, i.Path) + } + return nil + } + intVal, ok := value.(int) + if !ok { + return fmt.Errorf("invalid int value %v", value) + } + if i.MinValue != nil && intVal < *i.MinValue { + return fmt.Errorf("int value %v is less than lower bound %d", value, *i.MinValue) + } + if i.MaxValue != nil && intVal > *i.MaxValue { + return fmt.Errorf("int value %v is greater than upper bound %d", value, *i.MaxValue) + } + return nil +} + +func (i *IntProperty) SubProperties() property.PropertyMap { + return nil +} diff --git a/pkg/k2/constructs/template/properties/int_property_test.go b/pkg/k2/constructs/template/properties/int_property_test.go new file mode 100644 index 000000000..b1e7513e9 --- /dev/null +++ b/pkg/k2/constructs/template/properties/int_property_test.go @@ -0,0 +1,243 @@ +package properties + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" +) + +// Testing the SetProperty method for different cases +func Test_IntProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *IntProperty + input any + wantError bool + }{ + { + name: "valid int value", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: 42, + wantError: false, + }, + { + name: "invalid float value", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: float32(42.0), + wantError: true, + }, + { + name: "invalid string value", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: "invalid_int", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the ZeroValue method +func Test_IntProperty_ZeroValue(t *testing.T) { + assert := assert.New(t) + property := &IntProperty{} + assert.Equal(0, property.ZeroValue()) +} + +// Testing the Details method +func Test_IntProperty_Details(t *testing.T) { + assert := assert.New(t) + property := &IntProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +// Testing the Clone method +func Test_IntProperty_Clone(t *testing.T) { + property := &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + MinValue: new(int), + MaxValue: new(int), + } + clone := property.Clone() + assert.Equal(t, property, clone) +} + +// Testing the AppendProperty method with different cases +func Test_IntProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *IntProperty + properties construct.Properties + input any + wantError bool + }{ + { + name: "append int value", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: 42, + wantError: false, + }, + { + name: "append invalid float value", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: float32(42.0), + wantError: true, + }, + { + name: "append invalid string value", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: "invalid_int", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.AppendProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the RemoveProperty method +func Test_IntProperty_RemoveProperty(t *testing.T) { + tests := []struct { + name string + property *IntProperty + properties construct.Properties + wantError bool + }{ + { + name: "remove existing int property", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{"test": 42}, + }, + { + name: "remove non-existent property", + property: &IntProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert := assert.New(t) + err := test.property.RemoveProperty(test.properties, nil) + if test.wantError { + assert.Error(err) + return + } else { + assert.NoError(err) + } + assert.NotContains(test.properties, test.property.Path) + }) + } +} + +// Testing the Parse method +func Test_IntProperty_Parse(t *testing.T) { + tests := []struct { + name string + property *IntProperty + input any + expected any + wantError bool + }{ + { + name: "parse string to int", + property: &IntProperty{}, + input: "42", + expected: 42, + wantError: false, + }, + { + name: "parse int", + property: &IntProperty{}, + input: 42, + expected: 42, + wantError: false, + }, + { + name: "parse float", + property: &IntProperty{}, + input: float32(42.0), + expected: 42, + wantError: false, + }, + { + name: "parse invalid string", + property: &IntProperty{}, + input: "invalid_int", + expected: nil, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.property.Parse(tt.input, DefaultExecutionContext{}, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// Testing the Contains method +func Test_IntProperty_Contains(t *testing.T) { + assert := assert.New(t) + property := &IntProperty{} + assert.False(property.Contains(1, 1)) +} + +// Testing the Type method +func Test_IntProperty_Type(t *testing.T) { + assert := assert.New(t) + property := &IntProperty{} + assert.Equal("int", property.Type()) +} + +// Testing the SubProperties method +func Test_IntProperty_SubProperties(t *testing.T) { + assert := assert.New(t) + property := &IntProperty{} + assert.Nil(property.SubProperties()) +} diff --git a/pkg/k2/constructs/template/properties/key_value_list.go b/pkg/k2/constructs/template/properties/key_value_list.go new file mode 100644 index 000000000..696d4919a --- /dev/null +++ b/pkg/k2/constructs/template/properties/key_value_list.go @@ -0,0 +1,229 @@ +package properties + +import ( + "errors" + "fmt" + "reflect" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" +) + +type ( + KeyValueListProperty struct { + MinLength *int + MaxLength *int + KeyProperty property.Property + ValueProperty property.Property + SharedPropertyFields + property.PropertyDetails + } + + KeyValuePair struct { + Key any `json:"key"` + Value any `json:"value"` + } +) + +func (kvl *KeyValueListProperty) SetProperty(properties construct.Properties, value any) error { + list, err := kvl.mapToList(value) + if err != nil { + return err + } + return properties.SetProperty(kvl.Path, list) +} + +func (kvl *KeyValueListProperty) AppendProperty(properties construct.Properties, value any) error { + list, err := kvl.mapToList(value) + if err != nil { + return err + } + propVal, err := properties.GetProperty(kvl.Path) + if err != nil && !errors.Is(err, construct.ErrPropertyDoesNotExist) { + return err + } + if propVal == nil { + return properties.SetProperty(kvl.Path, list) + } + existingList, ok := propVal.([]any) + if !ok { + return fmt.Errorf("invalid existing property value %v", propVal) + } + return properties.SetProperty(kvl.Path, append(existingList, list...)) +} + +func (kvl *KeyValueListProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(kvl.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + existingList, ok := propVal.([]any) + if !ok { + return fmt.Errorf("invalid existing property value %v", propVal) + } + removeList, err := kvl.mapToList(value) + if err != nil { + return err + } + filteredList := make([]any, 0, len(existingList)) + for _, item := range existingList { + if !kvl.containsKeyValuePair(removeList, item) { + filteredList = append(filteredList, item) + } + } + return properties.SetProperty(kvl.Path, filteredList) +} + +func (kvl *KeyValueListProperty) Details() *property.PropertyDetails { + return &kvl.PropertyDetails +} + +func (kvl *KeyValueListProperty) Clone() property.Property { + clone := *kvl + if kvl.KeyProperty != nil { + clone.KeyProperty = kvl.KeyProperty.Clone() + } + if kvl.ValueProperty != nil { + clone.ValueProperty = kvl.ValueProperty.Clone() + } + return &clone +} + +func (kvl *KeyValueListProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if kvl.DefaultValue == nil { + return nil, nil + } + return kvl.Parse(kvl.DefaultValue, ctx, data) +} + +func (kvl *KeyValueListProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + list, err := kvl.mapToList(value) + if err != nil { + return nil, err + } + result := make([]any, 0, len(list)) + for _, item := range list { + pair, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("invalid key-value pair %v", item) + } + + key, err := kvl.KeyProperty.Parse(pair[kvl.KeyPropertyName()], ctx, data) + if err != nil { + return nil, fmt.Errorf("error parsing key: %w", err) + } + value, err := kvl.ValueProperty.Parse(pair[kvl.ValuePropertyName()], ctx, data) + if err != nil { + return nil, fmt.Errorf("error parsing value: %w", err) + } + result = append(result, map[string]any{ + kvl.KeyPropertyName(): key, + kvl.ValuePropertyName(): value, + }) + } + return result, nil +} + +func (kvl *KeyValueListProperty) KeyPropertyName() string { + return kvl.KeyProperty.Details().Name +} + +func (kvl *KeyValueListProperty) ValuePropertyName() string { + return kvl.ValueProperty.Details().Name +} + +func (kvl *KeyValueListProperty) ZeroValue() any { + return nil +} + +func (kvl *KeyValueListProperty) Contains(value any, contains any) bool { + list, err := kvl.mapToList(value) + if err != nil { + return false + } + containsList, err := kvl.mapToList(contains) + if err != nil { + return false + } + for _, item := range containsList { + if kvl.containsKeyValuePair(list, item) { + return true + } + } + return false +} + +func (kvl *KeyValueListProperty) Type() string { + return fmt.Sprintf("keyvaluelist(%s,%s)", kvl.KeyProperty.Type(), kvl.ValueProperty.Type()) +} + +func (kvl *KeyValueListProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if kvl.Required { + return fmt.Errorf(property.ErrRequiredProperty, kvl.Path) + } + return nil + } + list, err := kvl.mapToList(value) + if err != nil { + return err + } + if kvl.MinLength != nil && len(list) < *kvl.MinLength { + return fmt.Errorf("list value %v is too short. min length is %d", value, *kvl.MinLength) + } + if kvl.MaxLength != nil && len(list) > *kvl.MaxLength { + return fmt.Errorf("list value %v is too long. max length is %d", value, *kvl.MaxLength) + } + var errs error + for _, item := range list { + pair, ok := item.(map[string]any) + if !ok { + errs = errors.Join(errs, fmt.Errorf("invalid key-value pair %v", item)) + continue + } + if err := kvl.KeyProperty.Validate(properties, pair[kvl.KeyPropertyName()]); err != nil { + errs = errors.Join(errs, fmt.Errorf("invalid key %v: %w", pair[kvl.KeyPropertyName()], err)) + } + if err := kvl.ValueProperty.Validate(properties, pair[kvl.ValuePropertyName()]); err != nil { + errs = errors.Join(errs, fmt.Errorf("invalid value %v: %w", pair[kvl.ValuePropertyName()], err)) + } + } + return errs +} + +func (kvl *KeyValueListProperty) SubProperties() property.PropertyMap { + return nil +} + +func (kvl *KeyValueListProperty) mapToList(value any) ([]any, error) { + switch v := value.(type) { + case []any: + return v, nil + case map[string]any: + result := make([]any, 0, len(v)) + for key, val := range v { + result = append(result, map[string]any{ + kvl.KeyPropertyName(): key, + kvl.ValuePropertyName(): val, + }) + } + return result, nil + default: + return nil, fmt.Errorf("invalid input type for KeyValueListProperty: %T", value) + } +} + +func (kvl *KeyValueListProperty) containsKeyValuePair(list []any, item any) bool { + for _, listItem := range list { + if reflect.DeepEqual(listItem, item) { + return true + } + } + return false +} diff --git a/pkg/k2/constructs/template/properties/key_value_list_test.go b/pkg/k2/constructs/template/properties/key_value_list_test.go new file mode 100644 index 000000000..d85b0b05b --- /dev/null +++ b/pkg/k2/constructs/template/properties/key_value_list_test.go @@ -0,0 +1,379 @@ +package properties + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" +) + +func Test_KeyValueListProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + input any + expected []any + wantError bool + }{ + { + name: "set property with map input", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + input: map[string]any{"key1": "value1", "key2": "value2"}, + expected: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + wantError: false, + }, + { + name: "set property with list input", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + input: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + expected: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + wantError: false, + }, + { + name: "set property with invalid input", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + input: "invalid input", + expected: nil, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + result, _ := properties.GetProperty(tt.property.Path) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_KeyValueListProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + initial []any + input any + expected []any + wantError bool + }{ + { + name: "append to existing list", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + initial: []any{ + map[string]any{"key": "key1", "value": "value1"}, + }, + input: map[string]any{"key2": "value2"}, + expected: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + properties.SetProperty(tt.property.Path, tt.initial) + err := tt.property.AppendProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + result, _ := properties.GetProperty(tt.property.Path) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_KeyValueListProperty_RemoveProperty(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + initial []any + input any + expected []any + wantError bool + }{ + { + name: "remove existing key-value pair", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + initial: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + input: map[string]any{"key1": "value1"}, + expected: []any{ + map[string]any{"key": "key2", "value": "value2"}, + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + properties.SetProperty(tt.property.Path, tt.initial) + err := tt.property.RemoveProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + result, _ := properties.GetProperty(tt.property.Path) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_KeyValueListProperty_GetDefaultValue(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + expectedValue any + wantError bool + }{ + { + name: "return default value", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{ + Path: "test", + }, + SharedPropertyFields: SharedPropertyFields{ + DefaultValue: map[string]any{"defaultKey": "defaultValue"}, + }, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + expectedValue: []any{map[string]any{"key": "defaultKey", "value": "defaultValue"}}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := DefaultExecutionContext{} + result, err := tt.property.GetDefaultValue(ctx, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedValue, result) + } + }) + } +} + +func Test_KeyValueListProperty_Parse(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + input any + expectedValue any + wantError bool + }{ + { + name: "parse valid input", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &IntProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + input: []any{ + map[string]any{"key": "key1", "value": "42"}, + map[string]any{"key": "key2", "value": "24"}, + }, + expectedValue: []any{ + map[string]any{"key": "key1", "value": 42}, + map[string]any{"key": "key2", "value": 24}, + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := DefaultExecutionContext{} + result, err := tt.property.Parse(tt.input, ctx, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedValue, result) + } + }) + } +} + +func Test_KeyValueListProperty_Contains(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + value any + contains any + expected bool + }{ + { + name: "contains key-value pair", + property: &KeyValueListProperty{ + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + value: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + contains: map[string]any{"key1": "value1"}, + expected: true, + }, + { + name: "does not contain key-value pair", + property: &KeyValueListProperty{ + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + value: []any{ + map[string]any{"key": "key1", "value": "value1"}, + map[string]any{"key": "key2", "value": "value2"}, + }, + contains: map[string]any{"key3": "value3"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.property.Contains(tt.value, tt.contains) + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_KeyValueListProperty_Type(t *testing.T) { + property := &KeyValueListProperty{ + KeyProperty: &StringProperty{}, + ValueProperty: &IntProperty{}, + } + assert.Equal(t, "keyvaluelist(string,int)", property.Type()) +} + +func Test_KeyValueListProperty_Validate(t *testing.T) { + tests := []struct { + name string + property *KeyValueListProperty + value any + wantError bool + }{ + { + name: "valid input", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &IntProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + value: []any{ + map[string]any{"key": "key1", "value": 42}, + map[string]any{"key": "key2", "value": 24}, + }, + wantError: false, + }, + { + name: "invalid key", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &IntProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + value: []any{ + map[string]any{"key": 42, "value": 42}, + }, + wantError: true, + }, + { + name: "invalid value", + property: &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &IntProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + }, + value: []any{ + map[string]any{"key": "key1", "value": "not an int"}, + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.Validate(properties, tt.value) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_KeyValueListProperty_Clone(t *testing.T) { + original := &KeyValueListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{PropertyDetails: property.PropertyDetails{Name: "key"}}, + ValueProperty: &IntProperty{PropertyDetails: property.PropertyDetails{Name: "value"}}, + MinLength: ptr(1), + MaxLength: ptr(10), + } + + clone := original.Clone().(*KeyValueListProperty) + + assert.Equal(t, original.PropertyDetails, clone.PropertyDetails) + assert.Equal(t, original.KeyProperty.Type(), clone.KeyProperty.Type()) + assert.Equal(t, original.ValueProperty.Type(), clone.ValueProperty.Type()) + assert.Equal(t, *original.MinLength, *clone.MinLength) + assert.Equal(t, *original.MaxLength, *clone.MaxLength) + assert.NotSame(t, original.KeyProperty, clone.KeyProperty) + assert.NotSame(t, original.ValueProperty, clone.ValueProperty) +} + +func Test_KeyValueListProperty_SubProperties(t *testing.T) { + property := &KeyValueListProperty{} + assert.Nil(t, property.SubProperties()) +} diff --git a/pkg/k2/constructs/template/properties/list_property.go b/pkg/k2/constructs/template/properties/list_property.go new file mode 100644 index 000000000..0d40e4425 --- /dev/null +++ b/pkg/k2/constructs/template/properties/list_property.go @@ -0,0 +1,264 @@ +package properties + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + + "github.com/klothoplatform/klotho/pkg/collectionutil" +) + +type ( + ListProperty struct { + MinLength *int + MaxLength *int + ItemProperty property.Property + Properties property.PropertyMap + SharedPropertyFields + property.PropertyDetails + } +) + +func (l *ListProperty) SetProperty(properties construct.Properties, value any) error { + if val, ok := value.([]any); ok { + return properties.SetProperty(l.Path, val) + } + return fmt.Errorf("invalid list value %v", value) +} + +func (l *ListProperty) AppendProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(l.Path) + if err != nil && !errors.Is(err, construct.ErrPropertyDoesNotExist) { + return err + } + if propVal == nil { + err := l.SetProperty(properties, []any{}) + if err != nil { + return err + } + } + if l.ItemProperty != nil && !strings.HasPrefix(l.ItemProperty.Type(), "list") { + if reflect.ValueOf(value).Kind() == reflect.Slice || reflect.ValueOf(value).Kind() == reflect.Array { + var errs error + for i := 0; i < reflect.ValueOf(value).Len(); i++ { + err := properties.AppendProperty(l.Path, reflect.ValueOf(value).Index(i).Interface()) + if err != nil { + errs = errors.Join(errs, err) + } + } + return errs + } + } + return properties.AppendProperty(l.Path, value) +} + +func (l *ListProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(l.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + if l.ItemProperty != nil && !strings.HasPrefix(l.ItemProperty.Type(), "list") { + if reflect.ValueOf(value).Kind() == reflect.Slice || reflect.ValueOf(value).Kind() == reflect.Array { + var errs error + for i := 0; i < reflect.ValueOf(value).Len(); i++ { + err := properties.RemoveProperty(l.Path, reflect.ValueOf(value).Index(i).Interface()) + if err != nil { + errs = errors.Join(errs, err) + } + } + return errs + } + } + return properties.RemoveProperty(l.Path, value) +} + +func (l *ListProperty) Details() *property.PropertyDetails { + return &l.PropertyDetails +} + +func (l *ListProperty) Clone() property.Property { + var itemProp property.Property + if l.ItemProperty != nil { + itemProp = l.ItemProperty.Clone() + } + var props property.PropertyMap + if l.Properties != nil { + props = l.Properties.Clone() + } + clone := *l + clone.ItemProperty = itemProp + clone.Properties = props + return &clone +} + +func (list *ListProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if list.DefaultValue == nil { + return nil, nil + } + return list.Parse(list.DefaultValue, ctx, data) +} + +func (list *ListProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + + var result []any + val, ok := value.([]any) + if !ok { + // before we fail, check to see if the entire value is a template + if strVal, ok := value.(string); ok { + var result []any + err := ctx.ExecuteUnmarshal(strVal, data, &result) + if err != nil { + return nil, fmt.Errorf("invalid list value %v: %w", value, err) + } + val = result + } else { + return nil, fmt.Errorf("invalid list value %v", value) + } + } + + for _, v := range val { + if len(list.Properties) != 0 { + m := MapProperty{Properties: list.Properties} + val, err := m.Parse(v, ctx, data) + if err != nil { + return nil, err + } + result = append(result, val) + } else { + val, err := list.ItemProperty.Parse(v, ctx, data) + if err != nil { + return nil, err + } + result = append(result, val) + } + } + return result, nil +} + +func (l *ListProperty) ZeroValue() any { + return nil +} + +func (l *ListProperty) Contains(value any, contains any) bool { + list, ok := value.([]any) + if !ok { + return false + } + containsList, ok := contains.([]any) + if !ok { + return collectionutil.Contains(list, contains) + } + for _, v := range list { + for _, cv := range containsList { + if reflect.DeepEqual(v, cv) { + return true + } + } + } + return false +} + +func (l *ListProperty) Type() string { + if l.ItemProperty != nil { + return fmt.Sprintf("list(%s)", l.ItemProperty.Type()) + } + return "list" +} + +func (l *ListProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if l.Required { + return fmt.Errorf(property.ErrRequiredProperty, l.Path) + } + return nil + } + + listVal, ok := value.([]any) + if !ok { + return fmt.Errorf("invalid list value %v", value) + } + if l.MinLength != nil { + if len(listVal) < *l.MinLength { + return fmt.Errorf("list value %v is too short. min length is %d", value, *l.MinLength) + } + } + if l.MaxLength != nil { + if len(listVal) > *l.MaxLength { + return fmt.Errorf("list value %v is too long. max length is %d", value, *l.MaxLength) + } + } + + validList := make([]any, len(listVal)) + var errs error + hasSanitized := false + for i, v := range listVal { + if l.ItemProperty != nil { + err := l.ItemProperty.Validate(properties, v) + if err != nil { + var sanitizeErr *property.SanitizeError + if errors.As(err, &sanitizeErr) { + validList[i] = sanitizeErr.Sanitized + hasSanitized = true + } else { + errs = errors.Join(errs, err) + } + } else { + validList[i] = v + } + } else { + vmap, ok := v.(map[string]any) + if !ok { + return fmt.Errorf("invalid value for list index %d in sub properties validation: expected map[string]any got %T", i, v) + } + validIndex := make(map[string]any) + for _, prop := range l.SubProperties() { + val, ok := vmap[prop.Details().Name] + if !ok { + continue + } + err := prop.Validate(properties, val) + if err != nil { + var sanitizeErr *property.SanitizeError + if errors.As(err, &sanitizeErr) { + validIndex[prop.Details().Name] = sanitizeErr.Sanitized + hasSanitized = true + } else { + errs = errors.Join(errs, err) + } + } else { + validIndex[prop.Details().Name] = val + } + } + validList[i] = validIndex + } + } + if errs != nil { + return errs + } + if hasSanitized { + return &property.SanitizeError{ + Input: listVal, + Sanitized: validList, + } + } + + return nil +} + +func (l *ListProperty) SubProperties() property.PropertyMap { + return l.Properties +} + +func (l *ListProperty) Item() property.Property { + return l.ItemProperty +} diff --git a/pkg/k2/constructs/template/properties/list_property_test.go b/pkg/k2/constructs/template/properties/list_property_test.go new file mode 100644 index 000000000..e6cfe470e --- /dev/null +++ b/pkg/k2/constructs/template/properties/list_property_test.go @@ -0,0 +1,283 @@ +package properties + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" +) + +// Testing the SetProperty method for different cases +func Test_ListProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *ListProperty + input any + wantError bool + }{ + { + name: "valid list value", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: []any{"item1", "item2"}, + wantError: false, + }, + { + name: "invalid map value", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: map[string]any{"key": "value"}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the ZeroValue method +func Test_ListProperty_ZeroValue(t *testing.T) { + assert := assert.New(t) + property := &ListProperty{} + assert.Nil(property.ZeroValue()) +} + +// Testing the Details method +func Test_ListProperty_Details(t *testing.T) { + assert := assert.New(t) + property := &ListProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +// Testing the Clone method +func Test_ListProperty_Clone(t *testing.T) { + property := &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + ItemProperty: &StringProperty{}, + } + clone := property.Clone() + assert.Equal(t, property, clone) +} + +// Testing the AppendProperty method with different cases +func Test_ListProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *ListProperty + properties construct.Properties + input any + wantError bool + expect any + }{ + { + name: "append valid list value", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: []any{"item1", "item2"}, + expect: []any{"item1", "item2"}, + }, + { + name: "append string value", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: "item1", + expect: []any{"item1"}, + }, + { + name: "append to existing list", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{"test": []any{"item1"}}, + input: "item2", + expect: []any{"item1", "item2"}, + }, + { + // This test documents existing non-ideal behavior + name: "append allows invalid values", + property: &ListProperty{ + ItemProperty: &StringProperty{}, + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: map[string]any{"key": "value"}, + expect: []any{map[string]any{"key": "value"}}, + }, + } + + for _, tt := range tests { + assert := assert.New(t) + t.Run(tt.name, func(t *testing.T) { + err := tt.property.AppendProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(err) + return + } + assert.NoError(err) + assert.Equal(tt.expect, tt.properties[tt.property.Path]) + }) + } +} + +// Testing the RemoveProperty method +func Test_ListProperty_RemoveProperty(t *testing.T) { + tests := []struct { + name string + property *ListProperty + properties construct.Properties + input any + wantError bool + }{ + { + name: "remove existing list property", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{"test": []any{"item1", "item2"}}, + input: "item1", + wantError: false, + }, + { + name: "remove non-existent property", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: "item1", + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.RemoveProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Testing the Parse method +func Test_ListProperty_Parse(t *testing.T) { + tests := []struct { + name string + property *ListProperty + input any + expected any + wantError bool + }{ + { + name: "parse string to list", + property: &ListProperty{ + ItemProperty: &StringProperty{}, + }, + input: "[\"item1\", \"item2\"]", + expected: []any{"item1", "item2"}, + wantError: false, + }, + { + name: "parse valid list", + property: &ListProperty{ + ItemProperty: &StringProperty{}, + }, + input: []any{"item1", "item2"}, + expected: []any{"item1", "item2"}, + wantError: false, + }, + { + name: "parse invalid map", + property: &ListProperty{ + ItemProperty: &StringProperty{}, + }, + input: map[string]any{"key": "value"}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.property.Parse(tt.input, DefaultExecutionContext{}, nil) + if tt.wantError { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +// Testing the Contains method +func Test_ListProperty_Contains(t *testing.T) { + tests := []struct { + name string + property *ListProperty + value any + expected bool + }{ + { + name: "list contains value", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + ItemProperty: &StringProperty{}, + }, + value: []any{"test"}, + expected: true, + }, + { + name: "list does not contain value", + property: &ListProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + ItemProperty: &StringProperty{}, + }, + value: []any{"other"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.property.Contains(tt.value, "test") + assert.Equal(t, tt.expected, result) + }) + } +} + +// Testing the Type method +func Test_ListProperty_Type(t *testing.T) { + assert := assert.New(t) + property := &ListProperty{} + assert.Equal("list", property.Type()) + property2 := &ListProperty{ + ItemProperty: &StringProperty{}, + } + assert.Equal("list(string)", property2.Type()) +} + +// Testing the SubProperties method +func Test_ListProperty_SubProperties(t *testing.T) { + assert := assert.New(t) + property := &ListProperty{ + Properties: make(property.PropertyMap), + } + assert.NotNil(property.SubProperties()) +} diff --git a/pkg/k2/constructs/template/properties/map_property.go b/pkg/k2/constructs/template/properties/map_property.go new file mode 100644 index 000000000..7bf101b1f --- /dev/null +++ b/pkg/k2/constructs/template/properties/map_property.go @@ -0,0 +1,263 @@ +package properties + +import ( + "errors" + "fmt" + "reflect" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" +) + +type ( + MapProperty struct { + MinLength *int + MaxLength *int + KeyProperty property.Property + ValueProperty property.Property + Properties property.PropertyMap + SharedPropertyFields + property.PropertyDetails + } +) + +func (m *MapProperty) SetProperty(properties construct.Properties, value any) error { + if val, ok := value.(map[string]any); ok { + return properties.SetProperty(m.Path, val) + } + return fmt.Errorf("invalid properties value %v", value) +} + +func (m *MapProperty) AppendProperty(properties construct.Properties, value any) error { + return properties.AppendProperty(m.Path, value) +} + +func (m *MapProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(m.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + propMap, ok := propVal.(map[string]any) + if !ok { + return fmt.Errorf("error attempting to remove map property: invalid property value %v", propVal) + } + if val, ok := value.(map[string]any); ok { + for k, v := range val { + if val, found := propMap[k]; found && reflect.DeepEqual(val, v) { + delete(propMap, k) + } + } + return properties.SetProperty(m.Path, propMap) + } + return properties.RemoveProperty(m.Path, value) +} + +func (m *MapProperty) Details() *property.PropertyDetails { + return &m.PropertyDetails +} + +func (m *MapProperty) Clone() property.Property { + var keyProp property.Property + if m.KeyProperty != nil { + keyProp = m.KeyProperty.Clone() + } + var valProp property.Property + if m.ValueProperty != nil { + valProp = m.ValueProperty.Clone() + } + var props property.PropertyMap + if m.Properties != nil { + props = m.Properties.Clone() + } + clone := *m + clone.KeyProperty = keyProp + clone.ValueProperty = valProp + clone.Properties = props + return &clone +} + +func (m *MapProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if m.DefaultValue == nil { + return nil, nil + } + return m.Parse(m.DefaultValue, ctx, data) +} + +func (m *MapProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + result := map[string]any{} + + mapVal, ok := value.(map[string]any) + if !ok { + // before we fail, check to see if the entire value is a template + if strVal, ok := value.(string); ok { + err := ctx.ExecuteUnmarshal(strVal, data, &result) + return result, err + } + mapVal, ok = value.(construct.Properties) + if !ok { + return nil, fmt.Errorf("invalid map value %v", value) + } + } + // If we are an object with sub properties then we know that we need to get the type of our sub properties to determine how we are parsed into a value + if len(m.Properties) != 0 { + var errs error + for key, prop := range m.Properties { + if _, found := mapVal[key]; found { + val, err := prop.Parse(mapVal[key], ctx, data) + if err != nil { + errs = errors.Join(errs, fmt.Errorf("unable to parse value for sub property %s: %w", key, err)) + continue + } + result[key] = val + } else { + val, err := prop.GetDefaultValue(ctx, data) + if err != nil { + errs = errors.Join(errs, fmt.Errorf("unable to get default value for sub property %s: %w", key, err)) + continue + } + if val == nil { + continue + } + result[key] = val + } + } + } + + if m.KeyProperty == nil || m.ValueProperty == nil { + return result, nil + } + + // Else we are a set type of map and can just loop over the values + for key, v := range mapVal { + keyVal, err := m.KeyProperty.Parse(key, ctx, data) + if err != nil { + return nil, err + } + val, err := m.ValueProperty.Parse(v, ctx, data) + if err != nil { + return nil, err + } + switch keyVal := keyVal.(type) { + case string: + result[keyVal] = val + //case constructs.ConstructId: + // result[keyVal.String()] = val + //case construct.PropertyRef: + // result[keyVal.String()] = val + default: + return nil, fmt.Errorf("invalid key type for map property type %s", keyVal) + } + } + return result, nil +} + +func (m *MapProperty) ZeroValue() any { + return nil +} + +func (m *MapProperty) Contains(value any, contains any) bool { + mapVal, ok := value.(map[string]any) + if !ok { + return false + } + containsMap, ok := contains.(map[string]any) + if !ok { + return false + } + for k, v := range containsMap { + if val, found := mapVal[k]; found || reflect.DeepEqual(val, v) { + return true + } + } + for _, v := range mapVal { + for _, cv := range containsMap { + if reflect.DeepEqual(v, cv) { + return true + } + } + } + return false +} + +func (m *MapProperty) Type() string { + if m.KeyProperty != nil && m.ValueProperty != nil { + return fmt.Sprintf("map(%s,%s)", m.KeyProperty.Type(), m.ValueProperty.Type()) + } + return "map" +} + +func (m *MapProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if m.Required { + return fmt.Errorf(property.ErrRequiredProperty, m.Path) + } + return nil + } + mapVal, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("invalid map value %v", value) + } + if m.MinLength != nil { + if len(mapVal) < *m.MinLength { + return fmt.Errorf("map value %v is too short. min length is %d", value, *m.MinLength) + } + } + if m.MaxLength != nil { + if len(mapVal) > *m.MaxLength { + return fmt.Errorf("map value %v is too long. max length is %d", value, *m.MaxLength) + } + } + var errs error + hasSanitized := false + validMap := make(map[string]any) + // Only validate values if it's a primitive map, otherwise let the sub properties handle their own validation + for k, v := range mapVal { + if m.KeyProperty != nil { + var sanitizeErr *property.SanitizeError + if err := m.KeyProperty.Validate(properties, k); errors.As(err, &sanitizeErr) { + k = sanitizeErr.Sanitized.(string) + hasSanitized = true + } else if err != nil { + errs = errors.Join(errs, fmt.Errorf("invalid key %v for map property type %s: %w", k, m.KeyProperty.Type(), err)) + } + } + if m.ValueProperty != nil { + var sanitizeErr *property.SanitizeError + if err := m.ValueProperty.Validate(properties, v); errors.As(err, &sanitizeErr) { + v = sanitizeErr.Sanitized + hasSanitized = true + } else if err != nil { + errs = errors.Join(errs, fmt.Errorf("invalid value %v for map property type %s: %w", v, m.ValueProperty.Type(), err)) + } + } + validMap[k] = v + } + if errs != nil { + return errs + } + if hasSanitized { + return &property.SanitizeError{ + Input: mapVal, + Sanitized: validMap, + } + } + return nil +} + +func (m *MapProperty) SubProperties() property.PropertyMap { + return m.Properties +} + +func (m *MapProperty) Key() property.Property { + return m.KeyProperty +} + +func (m *MapProperty) Value() property.Property { + return m.ValueProperty +} diff --git a/pkg/k2/constructs/template/properties/map_propery_test.go b/pkg/k2/constructs/template/properties/map_propery_test.go new file mode 100644 index 000000000..2f3a81c3f --- /dev/null +++ b/pkg/k2/constructs/template/properties/map_propery_test.go @@ -0,0 +1,359 @@ +package properties + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" +) + +func Test_MapProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *MapProperty + input any + wantError bool + }{ + { + name: "valid map value", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: map[string]any{"key1": "value1", "key2": "value2"}, + wantError: false, + }, + { + name: "invalid value type", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: "invalid_value", + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_MapProperty_ZeroValue(t *testing.T) { + assert := assert.New(t) + property := &MapProperty{} + assert.Equal(nil, property.ZeroValue()) +} + +func Test_MapProperty_Details(t *testing.T) { + assert := assert.New(t) + property := &MapProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +func Test_MapProperty_Clone(t *testing.T) { + property := &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + MinLength: new(int), + MaxLength: new(int), + } + clone := property.Clone() + assert.Equal(t, property, clone) +} + +func Test_MapProperty_Validate(t *testing.T) { + tests := []struct { + name string + property *MapProperty + value any + wantErr bool + minLength int + maxLength int + }{ + { + name: "valid map length within range", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + MaxLength: ptr(2), + }, + value: map[string]any{"key1": "value1"}, + wantErr: false, + }, + { + name: "map length too short", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + MinLength: ptr(1), + }, + value: map[string]any{}, + wantErr: true, + }, + { + name: "map length too long", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + MaxLength: ptr(1), + }, + value: map[string]any{"key1": "value1", "key2": "value2", "key3": "value3"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.Validate(properties, tt.value) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_MapProperty_SubProperties(t *testing.T) { + assert := assert.New(t) + property := &MapProperty{} + assert.Nil(property.SubProperties()) +} + +func Test_MapProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *MapProperty + initial map[string]any + input any + expected map[string]any + wantError bool + }{ + { + name: "append to existing map", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + initial: map[string]any{"existing": "value"}, + input: map[string]any{"new": "appended"}, + expected: map[string]any{"existing": "value", "new": "appended"}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + properties.SetProperty("test", tt.initial) + err := tt.property.AppendProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + result, _ := properties.GetProperty("test") + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_MapProperty_RemoveProperty(t *testing.T) { + tests := []struct { + name string + property *MapProperty + initial map[string]any + input any + expected map[string]any + wantError bool + }{ + { + name: "remove existing key", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + initial: map[string]any{"key1": "value1", "key2": "value2"}, + input: map[string]any{"key1": "value1"}, + expected: map[string]any{"key2": "value2"}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + properties.SetProperty("test", tt.initial) + err := tt.property.RemoveProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + result, _ := properties.GetProperty("test") + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func Test_MapProperty_GetDefaultValue(t *testing.T) { + tests := []struct { + name string + property *MapProperty + expectedValue any + wantError bool + }{ + { + name: "return default value", + property: &MapProperty{ + Properties: map[string]property.Property{ + "default": &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: ".default", Name: "default"}, + }}, + SharedPropertyFields: SharedPropertyFields{DefaultValue: map[string]any{"default": "value"}}, + }, + expectedValue: map[string]any{"default": "value"}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := DefaultExecutionContext{} + result, err := tt.property.GetDefaultValue(ctx, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedValue, result) + } + }) + } +} + +func Test_MapProperty_Parse(t *testing.T) { + tests := []struct { + name string + property *MapProperty + input any + expectedValue any + wantError bool + }{ + { + name: "parse with sub-properties", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + Properties: property.PropertyMap{ + "subKey": &StringProperty{PropertyDetails: property.PropertyDetails{Path: "subKey"}}, + }, + }, + input: map[string]any{"subKey": "value"}, + expectedValue: map[string]any{"subKey": "value"}, + wantError: false, + }, + { + name: "parse with key and value properties", + property: &MapProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + KeyProperty: &StringProperty{}, + ValueProperty: &IntProperty{}, + }, + input: map[string]any{"key": "42"}, + expectedValue: map[string]any{"key": 42}, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := DefaultExecutionContext{} + result, err := tt.property.Parse(tt.input, ctx, nil) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedValue, result) + } + }) + } +} + +func Test_MapProperty_Contains(t *testing.T) { + tests := []struct { + name string + property *MapProperty + value any + contains any + expected bool + }{ + { + name: "contains key-value pair", + property: &MapProperty{}, + value: map[string]any{"key1": "value1", "key2": "value2"}, + contains: map[string]any{"key1": "value1"}, + expected: true, + }, + { + name: "does not contain key-value pair", + property: &MapProperty{}, + value: map[string]any{"key1": "value1", "key2": "value2"}, + contains: map[string]any{"key3": "value3"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.property.Contains(tt.value, tt.contains) + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_MapProperty_Type(t *testing.T) { + tests := []struct { + name string + property *MapProperty + expected string + }{ + { + name: "type with key and value properties", + property: &MapProperty{ + KeyProperty: &StringProperty{}, + ValueProperty: &IntProperty{}, + }, + expected: "map(string,int)", + }, + { + name: "type without key and value properties", + property: &MapProperty{}, + expected: "map", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.property.Type() + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_MapProperty_Key(t *testing.T) { + keyProp := &StringProperty{} + prop := &MapProperty{KeyProperty: keyProp} + assert.Same(t, keyProp, prop.Key()) +} + +func Test_MapProperty_Value(t *testing.T) { + valueProp := &IntProperty{} + prop := &MapProperty{ValueProperty: valueProp} + assert.Same(t, valueProp, prop.Value()) +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/pkg/k2/constructs/template/properties/path_property.go b/pkg/k2/constructs/template/properties/path_property.go new file mode 100644 index 000000000..4023437b0 --- /dev/null +++ b/pkg/k2/constructs/template/properties/path_property.go @@ -0,0 +1,161 @@ +package properties + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + + "github.com/klothoplatform/klotho/pkg/collectionutil" +) + +type ( + PathProperty struct { + SanitizeTmpl *property.SanitizeTmpl + AllowedValues []string + SharedPropertyFields + property.PropertyDetails + RelativeTo string + } +) + +func (p *PathProperty) SetProperty(properties construct.Properties, value any) error { + strVal, ok := value.(string) + if !ok { + return fmt.Errorf("value %v is not a string", value) + } + if strVal == "" { + return properties.SetProperty(p.Path, "") + } + + path, err := resolvePath(strVal, p.RelativeTo) + if err != nil { + return err + } + return properties.SetProperty(p.Path, path) +} + +func (p *PathProperty) AppendProperty(properties construct.Properties, value any) error { + return p.SetProperty(properties, value) +} + +func (p *PathProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(p.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + return properties.RemoveProperty(p.Path, nil) +} + +func (p *PathProperty) Details() *property.PropertyDetails { + return &p.PropertyDetails +} + +func (p *PathProperty) Clone() property.Property { + clone := *p + return &clone +} + +func (p *PathProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if p.DefaultValue == nil { + return p.ZeroValue(), nil + } + return p.Parse(p.DefaultValue, ctx, data) +} + +func (p *PathProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + strVal := "" + switch val := value.(type) { + case string: + err := ctx.ExecuteUnmarshal(val, data, &strVal) + if err != nil { + return nil, err + } + strVal = val + case int, int32, int64, float32, float64, bool: + strVal = fmt.Sprintf("%v", val) + default: + return nil, fmt.Errorf("could not parse string property: invalid string value %v (%[1]T)", value) + } + if strVal == "" { + return "", nil + } + return resolvePath(strVal, p.RelativeTo) +} + +func (p *PathProperty) ZeroValue() any { + return "" +} + +func (p *PathProperty) Contains(value any, contains any) bool { + vString, ok := value.(string) + if !ok { + return false + } + cString, ok := contains.(string) + if !ok { + return false + } + return strings.Contains(vString, cString) +} + +func (p *PathProperty) Type() string { + return "string" +} + +func (p *PathProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if p.Required { + return fmt.Errorf(property.ErrRequiredProperty, p.Path) + } + return nil + } + stringVal, ok := value.(string) + if !ok { + return fmt.Errorf("value %v is not a string", value) + } + + if len(p.AllowedValues) > 0 && !collectionutil.Contains(p.AllowedValues, stringVal) { + return fmt.Errorf("value %s is not allowed. allowed values are %s", stringVal, p.AllowedValues) + } + + if p.SanitizeTmpl != nil { + return p.SanitizeTmpl.Check(stringVal) + } + return nil +} + +func (p *PathProperty) SubProperties() property.PropertyMap { + return nil +} + +func resolvePath(path string, basePath string) (string, error) { + // If the path is absolute, return it as is + if filepath.IsAbs(path) { + return path, nil + } + + // Otherwise, make it relative to the base path or the current working directory + if basePath == "" { + var err error + basePath, err = os.Getwd() + if err != nil { + return "", fmt.Errorf("could not get working directory") + } + } + abs, err := filepath.Abs(filepath.Join(basePath, path)) + if err != nil { + return "", fmt.Errorf("could not resolve path %s: %w", path, err) + } + return abs, nil +} diff --git a/pkg/k2/constructs/template/properties/properties.go b/pkg/k2/constructs/template/properties/properties.go new file mode 100644 index 000000000..68124a2e3 --- /dev/null +++ b/pkg/k2/constructs/template/properties/properties.go @@ -0,0 +1,138 @@ +package properties + +import ( + "bytes" + "encoding" + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "text/template" + + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/klothoplatform/klotho/pkg/templateutils" +) + +type ( + DefaultExecutionContext struct{} +) + +// UnmarshalFunc decodes data into the supplied pointer, v +type UnmarshalFunc func(data *bytes.Buffer, v any) error + +func (d DefaultExecutionContext) ExecuteUnmarshal(tmpl string, data any, v any) error { + parsedTemplate, err := template.New("tmpl").Funcs(templateutils.WithCommonFuncs(template.FuncMap{})).Parse(tmpl) + if err != nil { + return err + } + + return ExecuteTemplateUnmarshal(parsedTemplate, data, v, d.Unmarshal) +} + +func (d DefaultExecutionContext) Unmarshal(data *bytes.Buffer, v any) error { + return UnmarshalAny(data, v) +} + +// ExecuteTemplateUnmarshal executes the [template.Template], t, using data and unmarshals the value into v +func ExecuteTemplateUnmarshal( + t *template.Template, + data any, + v any, + unmarshal UnmarshalFunc, +) error { + buf := new(bytes.Buffer) + if err := t.Execute(buf, data); err != nil { + return err + } + + if err := unmarshal(buf, v); err != nil { + return fmt.Errorf("cannot decode template result '%s' into %T", buf, v) + } + + return nil +} + +func UnmarshalJSON(data *bytes.Buffer, outputRefValue any) error { + dec := json.NewDecoder(data) + return dec.Decode(outputRefValue) +} + +// UnmarshalAny decodes the template result into a primitive or a struct that implements [encoding.TextUnmarshaler]. +// As a fallback, it tries to unmarshal the result using [json.Unmarshal]. +// If v is a pointer, it will be set to the decoded value. +func UnmarshalAny(data *bytes.Buffer, v any) error { + // trim the spaces, so you don't have to sprinkle the templates with `{{-` and `-}}` (the `-` trims spaces) + bstr := strings.TrimSpace(data.String()) + switch value := v.(type) { + case *string: + *value = bstr + return nil + + case *[]byte: + *value = []byte(bstr) + return nil + + case *bool: + result := strings.ToLower(bstr) + // If the input (eg 'field') is nil and the 'if' statement just uses '{{ inputs "field" }}', + // then the string result will be ''. + // Make sure we don't interpret that as a true condition. + *value = result != "" && result != "" && strings.ToLower(result) != "false" + return nil + case *int: + i, err := strconv.Atoi(bstr) + if err != nil { + return err + } + *value = i + return nil + case *float64: + f, err := strconv.ParseFloat(bstr, 64) + if err != nil { + return err + } + *value = f + return nil + case *float32: + f, err := strconv.ParseFloat(bstr, 32) + if err != nil { + return err + } + *value = float32(f) + return nil + + case encoding.TextUnmarshaler: + // notably, this handles `construct.ResourceId` and `construct.IaCValue` + return value.UnmarshalText([]byte(bstr)) + } + + resultStr := reflect.ValueOf(data.String()) + valueRefl := reflect.ValueOf(v).Elem() + if resultStr.Type().AssignableTo(valueRefl.Type()) { + // this covers alias types like `type MyString string` + valueRefl.Set(resultStr) + return nil + } + + err := json.Unmarshal([]byte(bstr), v) + if err == nil { + return nil + } + + return err + +} + +func ExecuteUnmarshalAsURN(ctx property.ExecutionContext, tmpl string, data any) (model.URN, error) { + var selector model.URN + err := ctx.ExecuteUnmarshal(tmpl, data, &selector) + if err != nil { + return selector, err + } + if selector.IsZero() { + return selector, fmt.Errorf("selector '%s' is zero", tmpl) + } + return selector, nil +} diff --git a/pkg/k2/constructs/template/properties/set_property.go b/pkg/k2/constructs/template/properties/set_property.go new file mode 100644 index 000000000..e2fc8df0c --- /dev/null +++ b/pkg/k2/constructs/template/properties/set_property.go @@ -0,0 +1,230 @@ +package properties + +import ( + "errors" + "fmt" + "reflect" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + + "github.com/klothoplatform/klotho/pkg/set" +) + +type ( + SetProperty struct { + MinLength *int + MaxLength *int + ItemProperty property.Property + Properties property.PropertyMap + SharedPropertyFields + property.PropertyDetails + } +) + +func (s *SetProperty) SetProperty(properties construct.Properties, value any) error { + switch val := value.(type) { + case set.HashedSet[string, any]: + return properties.SetProperty(s.Path, val) + } + + if val, ok := value.(set.HashedSet[string, any]); ok { + return properties.SetProperty(s.Path, val) + } + return fmt.Errorf("invalid set value %v", value) +} + +func (s *SetProperty) AppendProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(s.Path) + if err != nil && !errors.Is(err, construct.ErrPropertyDoesNotExist) { + return err + } + if propVal == nil { + if val, ok := value.(set.HashedSet[string, any]); ok { + return s.SetProperty(properties, val) + } + } + return properties.AppendProperty(s.Path, value) +} + +func (s *SetProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(s.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + propSet, ok := propVal.(set.HashedSet[string, any]) + if !ok { + return errors.New("invalid set value") + } + if val, ok := value.(set.HashedSet[string, any]); ok { + for _, v := range val.ToSlice() { + propSet.Remove(v) + } + } else { + return fmt.Errorf("invalid set value %v", value) + } + return s.SetProperty(properties, propSet) +} + +func (s *SetProperty) Details() *property.PropertyDetails { + return &s.PropertyDetails +} + +func (s *SetProperty) Clone() property.Property { + var itemProp property.Property + if s.ItemProperty != nil { + itemProp = s.ItemProperty.Clone() + } + var props property.PropertyMap + if s.Properties != nil { + props = s.Properties.Clone() + } + clone := *s + clone.ItemProperty = itemProp + clone.Properties = props + return &clone +} + +func (s *SetProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if s.DefaultValue == nil { + return nil, nil + } + return s.Parse(s.DefaultValue, ctx, data) +} + +func (s *SetProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + var result = set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + Less: func(s1, s2 string) bool { + return s1 < s2 + }, + } + + var vals []any + if valSet, ok := value.(set.HashedSet[string, any]); ok { + vals = valSet.ToSlice() + } else if val, ok := value.([]any); ok { + vals = val + } else { + // before we fail, check to see if the entire value is a template + if strVal, ok := value.(string); ok { + err := ctx.ExecuteUnmarshal(strVal, data, &vals) + if err != nil { + return nil, err + } + } + } + + for _, v := range vals { + if len(s.Properties) != 0 { + m := MapProperty{Properties: s.Properties} + val, err := m.Parse(v, ctx, data) + if err != nil { + return nil, err + } + result.Add(val) + } else { + val, err := s.ItemProperty.Parse(v, ctx, data) + if err != nil { + return nil, err + } + result.Add(val) + } + } + return result, nil +} + +func (s *SetProperty) ZeroValue() any { + return nil +} + +func (s *SetProperty) Contains(value any, contains any) bool { + valSet, ok := value.(set.HashedSet[string, any]) + if !ok { + return false + } + + for _, val := range valSet.M { + if reflect.DeepEqual(contains, val) { + return true + } + } + + return false +} + +func (s *SetProperty) Type() string { + if s.ItemProperty != nil { + return fmt.Sprintf("set(%s)", s.ItemProperty.Type()) + } + return "set" +} + +func (s *SetProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if s.Required { + return fmt.Errorf(property.ErrRequiredProperty, s.Path) + } + return nil + } + setVal, ok := value.(set.HashedSet[string, any]) + if !ok { + return fmt.Errorf("could not validate set property: invalid set value %v", value) + } + if s.MinLength != nil { + if setVal.Len() < *s.MinLength { + return fmt.Errorf("value %s is too short. minimum length is %d", setVal.M, *s.MinLength) + } + } + if s.MaxLength != nil { + if setVal.Len() > *s.MaxLength { + return fmt.Errorf("value %s is too long. maximum length is %d", setVal.M, *s.MaxLength) + } + } + + // Only validate values if its a primitive list, otherwise let the sub properties handle their own validation + if s.ItemProperty != nil { + var errs error + hasSanitized := false + validSet := set.HashedSet[string, any]{Hasher: setVal.Hasher} + for _, item := range setVal.ToSlice() { + if err := s.ItemProperty.Validate(properties, item); err != nil { + var sanitizeErr *property.SanitizeError + if errors.As(err, &sanitizeErr) { + validSet.Add(sanitizeErr.Sanitized) + hasSanitized = true + } else { + errs = errors.Join(errs, fmt.Errorf("invalid item %v: %v", item, err)) + } + } else { + validSet.Add(item) + } + } + if errs != nil { + return errs + } + if hasSanitized { + return &property.SanitizeError{ + Input: setVal, + Sanitized: validSet, + } + } + } + return nil +} + +func (s *SetProperty) SubProperties() property.PropertyMap { + return s.Properties +} + +func (s *SetProperty) Item() property.Property { + return s.ItemProperty +} diff --git a/pkg/k2/constructs/template/properties/set_property_test.go b/pkg/k2/constructs/template/properties/set_property_test.go new file mode 100644 index 000000000..e9fc697bd --- /dev/null +++ b/pkg/k2/constructs/template/properties/set_property_test.go @@ -0,0 +1,328 @@ +package properties + +import ( + "fmt" + "testing" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/klothoplatform/klotho/pkg/knowledgebase" + "github.com/klothoplatform/klotho/pkg/set" + "github.com/stretchr/testify/assert" +) + +func Test_SetProperty_SetProperty(t *testing.T) { + tests := []struct { + name string + property *SetProperty + input any + wantError bool + }{ + { + name: "valid set value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "item1": "item1", + "item2": "item2", + }, + }, + wantError: false, + }, + { + name: "invalid list value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: []any{"item1", "item2"}, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_SetProperty_ZeroValue(t *testing.T) { + assert := assert.New(t) + property := &SetProperty{} + assert.Nil(property.ZeroValue()) +} + +// Testing the Details method +func Test_SetProperty_Details(t *testing.T) { + assert := assert.New(t) + property := &SetProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +func Test_SetProperty_Clone(t *testing.T) { + property := &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + ItemProperty: &StringProperty{}, + } + clone := property.Clone() + assert.Equal(t, property, clone) +} + +func Test_SetProperty_AppendProperty(t *testing.T) { + tests := []struct { + name string + property *SetProperty + properties construct.Properties + input any + wantError bool + }{ + { + name: "append valid set value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "item1": "item1", + "item2": "item2", + }, + }, + wantError: false, + }, + { + name: "append list value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: []any{"item1", "item2"}, + wantError: false, + }, + { + // This test documents a bug in the AppendProperty method + name: "append invalid map value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + properties: construct.Properties{}, + input: map[string]any{"key": "value"}, + wantError: false, // This should be true if the bug is fixed + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.property.AppendProperty(tt.properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_RemoveSetProperty(t *testing.T) { + tests := []struct { + name string + property *SetProperty + properties construct.Properties + value any + expected set.HashedSet[string, any] + }{ + { + name: "existing property", + properties: map[string]any{ + "test": set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "test2": "test2", + "test1": "test1", + }, + }, + }, + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{ + Path: "test", + }, + }, + value: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "test2": "test2", + }, + }, + expected: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "test1": "test1", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + err := tt.property.RemoveProperty(tt.properties, tt.value) + if !assert.NoError(err) { + return + } + assert.Equal(tt.expected.M, tt.properties[tt.property.Path].(set.HashedSet[string, any]).M) + }) + } +} + +func Test_SetParse(t *testing.T) { + tests := []struct { + name string + property *SetProperty + ctx knowledgebase.DynamicValueContext + data knowledgebase.DynamicValueData + value any + expected set.HashedSet[string, any] + wantErr bool + }{ + { + name: "set property", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{ + Path: "test", + }, + ItemProperty: &StringProperty{}, + }, + value: []any{"test1", "test2"}, + expected: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "test1": "test1", + "test2": "test2", + }, + }, + }, + { + name: "set property as template", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{ + Path: "test", + }, + ItemProperty: &StringProperty{}, + }, + value: []any{"{{ \"test1\" }}", "{{ \"test2\" }}"}, + expected: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "test1": "test1", + "test2": "test2", + }, + }, + }, + { + name: "non set throws error", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{ + Path: "test", + }, + }, + value: "test", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + ctx := DefaultExecutionContext{} + actual, err := tt.property.Parse(tt.value, ctx, nil) + if tt.wantErr { + assert.Error(err) + return + } + if !assert.NoError(err) { + return + } + assert.Equal(actual.(set.HashedSet[string, any]).M, tt.expected.M, "expected %v, got %v", tt.expected, actual) + }) + } +} +func Test_SetProperty_Contains(t *testing.T) { + tests := []struct { + name string + property *SetProperty + value any + expected bool + }{ + { + name: "set contains value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + ItemProperty: &StringProperty{}, + }, + value: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "test": "test", + }, + }, + expected: true, + }, + { + name: "set does not contain value", + property: &SetProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + ItemProperty: &StringProperty{}, + }, + value: set.HashedSet[string, any]{ + Hasher: func(s any) string { + return fmt.Sprintf("%v", s) + }, + M: map[string]any{ + "other": "other", + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.property.Contains(tt.value, "test") + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_SetProperty_Type(t *testing.T) { + assert := assert.New(t) + property := &SetProperty{} + assert.Equal("set", property.Type()) + property2 := &SetProperty{ + ItemProperty: &StringProperty{}, + } + assert.Equal("set(string)", property2.Type()) +} diff --git a/pkg/k2/constructs/template/properties/shared_property_fields.go b/pkg/k2/constructs/template/properties/shared_property_fields.go new file mode 100644 index 000000000..1d7a62d21 --- /dev/null +++ b/pkg/k2/constructs/template/properties/shared_property_fields.go @@ -0,0 +1,41 @@ +package properties + +import ( + "bytes" + "fmt" + "text/template" + + "github.com/klothoplatform/klotho/pkg/construct" +) + +type ( + SharedPropertyFields struct { + DefaultValue any `json:"default_value" yaml:"default_value"` + ValidityChecks []PropertyValidityCheck + } + + PropertyValidityCheck struct { + template *template.Template + } + ValidityCheckData struct { + Properties construct.Properties `json:"properties" yaml:"properties"` + Value any `json:"value" yaml:"value"` + } +) + +func (p *PropertyValidityCheck) Validate(value any, properties construct.Properties) error { + var buff bytes.Buffer + data := ValidityCheckData{ + Properties: properties, + Value: value, + } + err := p.template.Execute(&buff, data) + if err != nil { + return err + } + result := buff.String() + if result != "" { + return fmt.Errorf("invalid value %v: %s", value, result) + } + return nil +} diff --git a/pkg/k2/constructs/template/properties/string_property.go b/pkg/k2/constructs/template/properties/string_property.go new file mode 100644 index 000000000..218dcc8f8 --- /dev/null +++ b/pkg/k2/constructs/template/properties/string_property.go @@ -0,0 +1,122 @@ +package properties + +import ( + "errors" + "fmt" + "strings" + + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + + "github.com/klothoplatform/klotho/pkg/collectionutil" +) + +type ( + StringProperty struct { + SanitizeTmpl *property.SanitizeTmpl + AllowedValues []string + SharedPropertyFields + property.PropertyDetails + } +) + +func (str *StringProperty) SetProperty(properties construct.Properties, value any) error { + if val, ok := value.(string); ok { + return properties.SetProperty(str.Path, val) + } else if val, ok := value.(construct.PropertyRef); ok { + return properties.SetProperty(str.Path, val) + } + return fmt.Errorf("could not set string property: invalid string value %v", value) +} + +func (str *StringProperty) AppendProperty(properties construct.Properties, value any) error { + return str.SetProperty(properties, value) +} + +func (str *StringProperty) RemoveProperty(properties construct.Properties, value any) error { + propVal, err := properties.GetProperty(str.Path) + if errors.Is(err, construct.ErrPropertyDoesNotExist) { + return nil + } + if err != nil { + return err + } + if propVal == nil { + return nil + } + return properties.RemoveProperty(str.Path, nil) +} + +func (str *StringProperty) Details() *property.PropertyDetails { + return &str.PropertyDetails +} + +func (str *StringProperty) Clone() property.Property { + clone := *str + return &clone +} + +func (str *StringProperty) GetDefaultValue(ctx property.ExecutionContext, data any) (any, error) { + if str.DefaultValue == nil { + return nil, nil + } + return str.Parse(str.DefaultValue, ctx, data) +} + +func (str *StringProperty) Parse(value any, ctx property.ExecutionContext, data any) (any, error) { + switch val := value.(type) { + case string: + err := ctx.ExecuteUnmarshal(val, data, &val) + return val, err + + case int, int32, int64, float32, float64, bool: + return fmt.Sprintf("%v", val), nil + } + return nil, fmt.Errorf("could not parse string property: invalid string value %v (%[1]T)", value) +} + +func (str *StringProperty) ZeroValue() any { + return "" +} + +func (str *StringProperty) Contains(value any, contains any) bool { + vString, ok := value.(string) + if !ok { + return false + } + cString, ok := contains.(string) + if !ok { + return false + } + return strings.Contains(vString, cString) +} + +func (str *StringProperty) Type() string { + return "string" +} + +func (str *StringProperty) Validate(properties construct.Properties, value any) error { + if value == nil { + if str.Required { + return fmt.Errorf(property.ErrRequiredProperty, str.Path) + } + return nil + } + stringVal, ok := value.(string) + if !ok { + return fmt.Errorf("value %v is not a string", value) + } + + if len(str.AllowedValues) > 0 && !collectionutil.Contains(str.AllowedValues, stringVal) { + return fmt.Errorf("value %s is not allowed. allowed values are %s", stringVal, str.AllowedValues) + } + + if str.SanitizeTmpl != nil { + return str.SanitizeTmpl.Check(stringVal) + } + return nil +} + +func (str *StringProperty) SubProperties() property.PropertyMap { + return nil +} diff --git a/pkg/k2/constructs/template/properties/string_property_test.go b/pkg/k2/constructs/template/properties/string_property_test.go new file mode 100644 index 000000000..56b587387 --- /dev/null +++ b/pkg/k2/constructs/template/properties/string_property_test.go @@ -0,0 +1,181 @@ +package properties + +import ( + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_StringPropertySetProperty(t *testing.T) { + tests := []struct { + name string + property *StringProperty + input any + wantError bool + }{ + { + name: "valid string value", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: "valid_string", + wantError: false, + }, + { + name: "invalid value type", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + input: 123, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + err := tt.property.SetProperty(properties, tt.input) + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_StringPropertyZeroValue(t *testing.T) { + assert := assert.New(t) + property := &StringProperty{} + assert.Equal("", property.ZeroValue()) +} + +func Test_StringPropertyDetails(t *testing.T) { + assert := assert.New(t) + property := &StringProperty{} + assert.Same(&property.PropertyDetails, property.Details()) +} + +func Test_StringPropertyContains(t *testing.T) { + tests := []struct { + name string + property *StringProperty + value any + contains any + expected bool + }{ + { + name: "string contains value", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + value: "hello world", + contains: "hello", + expected: true, + }, + { + name: "string does not contain value", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + value: "hello world", + contains: "goodbye", + expected: false, + }, + { + name: "non-string value", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + value: 123, + contains: "1", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.property.Contains(tt.value, tt.contains) + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_StringPropertyType(t *testing.T) { + assert := assert.New(t) + property := &StringProperty{} + assert.Equal("string", property.Type()) +} + +func Test_StringPropertyValidate(t *testing.T) { + tests := []struct { + name string + property *StringProperty + value any + wantErr bool + sanitizeTemplate string + allowedValues []string + }{ + { + name: "valid string in allowed values", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + AllowedValues: []string{"allowed_value"}, + }, + value: "allowed_value", + wantErr: false, + }, + { + name: "string not in allowed values", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + AllowedValues: []string{"allowed_value"}, + }, + value: "disallowed_value", + wantErr: true, + }, + { + name: "invalid value type", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + value: 123, + wantErr: true, + }, + { + name: "sanitized string value", + property: &StringProperty{ + PropertyDetails: property.PropertyDetails{Path: "test"}, + }, + sanitizeTemplate: "{{ . | upper }}", + value: "TEST", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + properties := construct.Properties{} + if tt.sanitizeTemplate != "" { + tmpl, err := property.NewSanitizationTmpl(tt.name, tt.sanitizeTemplate) + if !assert.NoError(t, err) { + return + } + tt.property.SanitizeTmpl = tmpl + } + err := tt.property.Validate(properties, tt.value) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_StringPropertySubProperties(t *testing.T) { + assert := assert.New(t) + property := &StringProperty{} + assert.Nil(property.SubProperties()) +} diff --git a/pkg/k2/constructs/template/property/construct_type.go b/pkg/k2/constructs/template/property/construct_type.go new file mode 100644 index 000000000..f066e7d3a --- /dev/null +++ b/pkg/k2/constructs/template/property/construct_type.go @@ -0,0 +1,67 @@ +package property + +import ( + "fmt" + "regexp" + "strings" + + "github.com/klothoplatform/klotho/pkg/k2/model" + "gopkg.in/yaml.v3" +) + +type ConstructReference struct { + URN model.URN `yaml:"urn" json:"urn"` + Path string `yaml:"path" json:"path"` +} + +type ConstructType struct { + Package string `yaml:"package"` + Name string `yaml:"name"` +} + +var constructTypeRegexp = regexp.MustCompile(`^(?:([\w-]+)\.)+([\w-]+)$`) + +func (c *ConstructType) UnmarshalYAML(value *yaml.Node) error { + var typeString string + err := value.Decode(&typeString) + if err != nil { + return fmt.Errorf("failed to decode construct type: %w", err) + } + + if !constructTypeRegexp.MatchString(typeString) { + return fmt.Errorf("invalid construct type: %s", typeString) + } + + lastDot := strings.LastIndex(typeString, ".") + c.Name = typeString[lastDot+1:] + c.Package = typeString[:lastDot] + + return nil +} + +func (c *ConstructType) String() string { + return fmt.Sprintf("%s.%s", c.Package, c.Name) +} + +func (c *ConstructType) FromString(id string) error { + parts := strings.Split(id, ".") + if len(parts) < 2 { + return fmt.Errorf("invalid construct template id: %s", id) + } + c.Package = strings.Join(parts[:len(parts)-1], ".") + c.Name = parts[len(parts)-1] + return nil +} + +func ParseConstructType(id string) (ConstructType, error) { + var c ConstructType + err := c.FromString(id) + return c, err +} + +func (c *ConstructType) FromURN(urn model.URN) error { + if urn.Type != "construct" { + return fmt.Errorf("invalid urn type: %s", urn.Type) + } + return c.FromString(urn.Subtype) +} diff --git a/pkg/k2/constructs/template/property/construct_type_test.go b/pkg/k2/constructs/template/property/construct_type_test.go new file mode 100644 index 000000000..88b24e525 --- /dev/null +++ b/pkg/k2/constructs/template/property/construct_type_test.go @@ -0,0 +1,50 @@ +package property + +import ( + "testing" + + "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/stretchr/testify/assert" +) + +func TestConstructType_FromURN(t *testing.T) { + tests := []struct { + name string + input string + expected ConstructType + wantErr bool + }{ + { + name: "Valid URN", + input: "urn:accountid:project:dev::construct/package.name", + expected: ConstructType{Package: "package", Name: "name"}, + wantErr: false, + }, + { + name: "Invalid URN type", + input: "urn:accountid:project:dev::other/package.name", + wantErr: true, + }, + { + name: "Invalid URN format", + input: "urn:accountid:project:dev::construct/invalid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ctId ConstructType + urn, err := model.ParseURN(tt.input) + if assert.NoError(t, err) { + err = ctId.FromURN(*urn) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, ctId) + } + } + }) + } +} diff --git a/pkg/k2/constructs/template/property/interfaces.go b/pkg/k2/constructs/template/property/interfaces.go new file mode 100644 index 000000000..9d7a3e310 --- /dev/null +++ b/pkg/k2/constructs/template/property/interfaces.go @@ -0,0 +1,71 @@ +package property + +import ( + "bytes" + + "github.com/klothoplatform/klotho/pkg/construct" +) + +type ( + // ExecutionContext defines the methods to execute a go template and decode the result into a value + ExecutionContext interface { + // ExecuteUnmarshal executes the template tmpl using data as input and unmarshals the value into v + ExecuteUnmarshal(tmpl string, data any, v any) error + // Unmarshal unmarshals the template result into a value + Unmarshal(data *bytes.Buffer, v any) error + } + + Property interface { + // SetProperty sets the value of the property on the properties + SetProperty(properties construct.Properties, value any) error + // AppendProperty appends the value to the property on the properties + AppendProperty(properties construct.Properties, value any) error + // RemoveProperty removes the value from the property on the properties + RemoveProperty(properties construct.Properties, value any) error + // Details returns the property details for the property + Details() *PropertyDetails + // Clone returns a clone of the property + Clone() Property + + // Type returns the string representation of the property's type, as it should appear in a template + Type() string + // GetDefaultValue returns the default value for the property, + // pertaining to the specific data being passed in for execution + GetDefaultValue(ctx ExecutionContext, data any) (any, error) + // Validate ensures the value is valid for the property to `Set` (not `Append` for collection types) + // and returns an error if it is not + Validate(properties construct.Properties, value any) error + // SubProperties returns the sub properties of the input, if any. + // This is used for inputs that are complex structures, such as lists, sets, or maps + SubProperties() PropertyMap + // Parse parses a given value to ensure it is the correct type for the property. + // If the given value cannot be converted to the respective property type an error is returned. + // The returned value will always be the correct type for the property + Parse(value any, ctx ExecutionContext, data any) (any, error) + // ZeroValue returns the zero value for the property type + ZeroValue() any + // Contains returns true if the value contains the given value + Contains(value any, contains any) bool + } + + MapProperty interface { + // Key returns the property representing the keys of the map + Key() Property + // Value returns the property representing the values of the map + Value() Property + } + + CollectionProperty interface { + // Item returns the structure of the items within the collection + Item() Property + } + + Properties interface { + Clone() Properties + ForEach(c construct.Properties, f func(p Property) error) error + Get(key string) (Property, bool) + Set(key string, value Property) + Remove(key string) + AsMap() map[string]Property + } +) diff --git a/pkg/k2/constructs/template/property/properties.go b/pkg/k2/constructs/template/property/properties.go new file mode 100644 index 000000000..73866acb9 --- /dev/null +++ b/pkg/k2/constructs/template/property/properties.go @@ -0,0 +1,151 @@ +package property + +import ( + "errors" + "fmt" + "github.com/klothoplatform/klotho/pkg/construct" + "github.com/klothoplatform/klotho/pkg/set" + "reflect" + "sort" + "strings" +) + +// PropertyMap is a map of properties that can be used to represent complex data structures in a template +// Wrap this in a struct that implements the [Properties] interface when using it in a template +type PropertyMap map[string]Property + +func (p PropertyMap) Clone() PropertyMap { + newProps := make(PropertyMap, len(p)) + for k, v := range p { + newProps[k] = v.Clone() + } + return newProps +} + +func (p PropertyMap) Get(key string) (Property, bool) { + value, exists := p[key] + return value, exists +} + +func (p PropertyMap) Set(key string, value Property) { + p[key] = value +} + +func (p PropertyMap) Remove(key string) { + delete(p, key) +} + +func (p PropertyMap) ForEach(c construct.Properties, f func(p Property) error) error { + queue := []PropertyMap{p} + var props PropertyMap + var errs error + for len(queue) > 0 { + props, queue = queue[0], queue[1:] + + propKeys := make([]string, 0, len(props)) + for k := range props { + propKeys = append(propKeys, k) + } + sort.Strings(propKeys) + + for _, key := range propKeys { + prop := props[key] + err := f(prop) + if err != nil { + if errors.Is(err, ErrStopWalk) { + return nil + } + errs = errors.Join(errs, err) + continue + } + + if strings.HasPrefix(prop.Type(), "list") || strings.HasPrefix(prop.Type(), "set") { + p, err := c.GetProperty(prop.Details().Path) + if err != nil || p == nil { + continue + } + // Because lists/sets will start as empty, do not recurse into their sub-properties if it's not set. + // To allow for defaults within list objects and operational rules to be run, + // we will look inside the property to see if there are values. + if strings.HasPrefix(prop.Type(), "list") { + length := reflect.ValueOf(p).Len() + for i := 0; i < length; i++ { + subProperties := make(PropertyMap) + for subK, subProp := range prop.SubProperties() { + propTemplate := subProp.Clone() + ReplacePath(propTemplate, prop.Details().Path, fmt.Sprintf("%s[%d]", prop.Details().Path, i)) + subProperties[subK] = propTemplate + } + if len(subProperties) > 0 { + queue = append(queue, subProperties) + } + } + } else if strings.HasPrefix(prop.Type(), "set") { + hs, ok := p.(set.HashedSet[string, any]) + if !ok { + errs = errors.Join(errs, fmt.Errorf("could not cast property to set")) + continue + } + for i := range hs.ToSlice() { + subProperties := make(PropertyMap) + for subK, subProp := range prop.SubProperties() { + propTemplate := subProp.Clone() + ReplacePath(propTemplate, prop.Details().Path, fmt.Sprintf("%s[%d]", prop.Details().Path, i)) + subProperties[subK] = propTemplate + } + if len(subProperties) > 0 { + queue = append(queue, subProperties) + } + } + + } + } else if prop.SubProperties() != nil { + queue = append(queue, prop.SubProperties()) + } + } + } + return errs +} + +func GetProperty(properties PropertyMap, path string) Property { + fields := strings.Split(path, ".") +FIELDS: + for i, field := range fields { + currFieldName := strings.Split(field, "[")[0] + found := false + for name, property := range properties { + if name != currFieldName { + continue + } + found = true + if len(fields) == i+1 { + // use a clone resource so we can modify the name in case anywhere in the path + // has index strings or map keys + clone := property.Clone() + details := clone.Details() + details.Path = path + return clone + } else { + properties = property.SubProperties() + if len(properties) == 0 { + if mp, ok := property.(MapProperty); ok { + clone := mp.Value().Clone() + details := clone.Details() + details.Path = path + return clone + } else if cp, ok := property.(CollectionProperty); ok { + clone := cp.Item().Clone() + details := clone.Details() + details.Path = path + return clone + } + } + } + continue FIELDS + } + if !found { + return nil + } + } + return nil +} diff --git a/pkg/k2/constructs/template/property/property_details.go b/pkg/k2/constructs/template/property/property_details.go new file mode 100644 index 000000000..d90856f27 --- /dev/null +++ b/pkg/k2/constructs/template/property/property_details.go @@ -0,0 +1,21 @@ +package property + +type ( + // PropertyDetails defines the common details of a property + PropertyDetails struct { + Name string `json:"name" yaml:"name"` + // DefaultValue has to be any because it may be a template and it may be a value of the correct type + Namespace bool `yaml:"namespace"` + // Required defines if the property is required + Required bool `json:"required" yaml:"required"` + // ConfigurationDisabled defines if the property is allowed to be configured by the user + ConfigurationDisabled bool `json:"configuration_disabled" yaml:"configuration_disabled"` + // OperationalRule defines a rule that is executed at runtime to determine the value of the property + //OperationalRule *PropertyRule `json:"operational_rule" yaml:"operational_rule"` + // Description is a description of the property. This is not used in the engine solving, + // but is metadata returned by the `ListResourceTypes` CLI command. + Description string `json:"description" yaml:"description"` + // Path is the path to the property in the resource + Path string `json:"-" yaml:"-"` + } +) diff --git a/pkg/k2/constructs/template/property/sanitization.go b/pkg/k2/constructs/template/property/sanitization.go new file mode 100644 index 000000000..fc892d667 --- /dev/null +++ b/pkg/k2/constructs/template/property/sanitization.go @@ -0,0 +1,94 @@ +package property + +import ( + "bytes" + "crypto/sha256" + "fmt" + "regexp" + "strings" + "sync" + "text/template" +) + +type ( + SanitizeTmpl struct { + template *template.Template + } + + // SanitizeError is returned when a value is sanitized if the input is not valid. The Sanitized field + // is always the same type as the Input field. + SanitizeError struct { + Input any + Sanitized any + } +) + +func NewSanitizationTmpl(name string, tmpl string) (*SanitizeTmpl, error) { + t, err := template.New(name + "/sanitize"). + Funcs(template.FuncMap{ + "replace": func(pattern, replace, name string) (string, error) { + re, err := regexp.Compile(pattern) + if err != nil { + return name, err + } + return re.ReplaceAllString(name, replace), nil + }, + + "length": func(min, max int, name string) string { + if len(name) < min { + return name + strings.Repeat("0", min-len(name)) + } + if len(name) > max { + base := name[:max-8] + h := sha256.New() + fmt.Fprint(h, name) + x := fmt.Sprintf("%x", h.Sum(nil)) + return base + x[:8] + } + return name + }, + + "lower": strings.ToLower, + "upper": strings.ToUpper, + }). + Parse(tmpl) + return &SanitizeTmpl{ + template: t, + }, err +} + +var sanitizeBufs = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, +} + +func (t SanitizeTmpl) Execute(value string) (string, error) { + buf := sanitizeBufs.Get().(*bytes.Buffer) + defer sanitizeBufs.Put(buf) + buf.Reset() + + err := t.template.Execute(buf, value) + if err != nil { + return value, fmt.Errorf("could not execute sanitize name template on %q: %w", value, err) + } + return strings.TrimSpace(buf.String()), nil +} + +func (t SanitizeTmpl) Check(value string) error { + sanitized, err := t.Execute(value) + if err != nil { + return err + } + if sanitized != value { + return &SanitizeError{ + Input: value, + Sanitized: sanitized, + } + } + return nil +} + +func (err SanitizeError) Error() string { + return fmt.Sprintf("invalid value %q, suggested value: %q", err.Input, err.Sanitized) +} diff --git a/pkg/k2/constructs/template/property/util.go b/pkg/k2/constructs/template/property/util.go new file mode 100644 index 000000000..28783456c --- /dev/null +++ b/pkg/k2/constructs/template/property/util.go @@ -0,0 +1,19 @@ +package property + +import ( + "errors" + "strings" +) + +const ErrRequiredProperty = "required property %s is not set" + +var ErrStopWalk = errors.New("stop walk") + +// ReplacePath runs a simple [strings.ReplaceAll] on the path of the property and all of its sub properties. +// NOTE: this mutates the property, so make sure to [Property.Clone] it first if you don't want that. +func ReplacePath(p Property, original, replacement string) { + p.Details().Path = strings.ReplaceAll(p.Details().Path, original, replacement) + for _, prop := range p.SubProperties() { + ReplacePath(prop, original, replacement) + } +} diff --git a/pkg/k2/constructs/template/resource_ref.go b/pkg/k2/constructs/template/resource_ref.go new file mode 100644 index 000000000..e3af576b5 --- /dev/null +++ b/pkg/k2/constructs/template/resource_ref.go @@ -0,0 +1,83 @@ +package template + +import ( + "fmt" + "github.com/klothoplatform/klotho/pkg/k2/model" + "github.com/klothoplatform/klotho/pkg/reflectutil" + "reflect" + "text/template" +) + +type ( + ResourceRef struct { + ConstructURN model.URN + ResourceKey string + Property string + Type ResourceRefType + } + + ResourceRefType string + InterpolationSourceKey string + + InterpolationSource interface { + GetPropertySource() *PropertySource + } + + PropertySource struct { + source reflect.Value + } + + TemplateFuncSupplier interface { + GetTemplateFuncs() template.FuncMap + } +) + +const ( + // ResourceRefTypeTemplate is a reference to a resource template and will be fully resolved prior to constraint generation + // e.g., ${resources:resourceName.property} or ${resources:resourceName} + ResourceRefTypeTemplate ResourceRefType = "template" + // ResourceRefTypeIaC is a reference to an infrastructure as code resource that will be resolved by the engine + // e.g., ${resources:resourceName#property} + ResourceRefTypeIaC ResourceRefType = "iac" + // ResourceRefTypeInterpolated is an initial interpolation reference to a resource. + // An interpolated value will be evaluated during initial processing and will be converted to one of the other types. + ResourceRefTypeInterpolated ResourceRefType = "interpolated" +) + +func (r *ResourceRef) String() string { + if r.Type == ResourceRefTypeIaC { + return fmt.Sprintf("%s#%s", r.ResourceKey, r.Property) + } + return r.ResourceKey +} + +func NewPropertySource(source any) *PropertySource { + var v reflect.Value + if sv, ok := source.(reflect.Value); ok { + v = sv + } else { + v = reflect.ValueOf(source) + } + return &PropertySource{ + source: v, + } +} + +func (p *PropertySource) GetProperty(key string) (value any, ok bool) { + v, err := reflectutil.GetField(p.source, key) + if err != nil || !v.IsValid() { + return nil, false + } + return v.Interface(), true +} + +func GetTypedProperty[T any](source *PropertySource, key string) (T, bool) { + var typedField T + v, ok := source.GetProperty(key) + + if !ok { + return typedField, false + } + + return reflectutil.GetTypedValue[T](v) +} diff --git a/pkg/k2/constructs/template_loader.go b/pkg/k2/constructs/template/template_loader.go similarity index 74% rename from pkg/k2/constructs/template_loader.go rename to pkg/k2/constructs/template/template_loader.go index b74fbf0fd..976d54afe 100644 --- a/pkg/k2/constructs/template_loader.go +++ b/pkg/k2/constructs/template/template_loader.go @@ -1,8 +1,9 @@ -package constructs +package template import ( "embed" "fmt" + "github.com/klothoplatform/klotho/pkg/k2/constructs/template/property" "path/filepath" "strings" "sync" @@ -14,17 +15,17 @@ import ( var templates embed.FS var ( - cachedConstructs = make(map[ConstructTemplateId]ConstructTemplate) + cachedConstructs = make(map[property.ConstructType]ConstructTemplate) cachedBindings = make(map[string]BindingTemplate) mu sync.Mutex ) -func loadConstructTemplate(id ConstructTemplateId) (ConstructTemplate, error) { +func LoadConstructTemplate(id property.ConstructType) (ConstructTemplate, error) { mu.Lock() defer mu.Unlock() - if template, ok := cachedConstructs[id]; ok { + if ct, ok := cachedConstructs[id]; ok { - return template, nil + return ct, nil } if !strings.HasPrefix(id.Package, "klotho.") { @@ -42,17 +43,17 @@ func loadConstructTemplate(id ConstructTemplateId) (ConstructTemplate, error) { return ConstructTemplate{}, fmt.Errorf("failed to read file: %w", err) } - var template ConstructTemplate - if err := yaml.Unmarshal(fileContent, &template); err != nil { + var ct ConstructTemplate + if err := yaml.Unmarshal(fileContent, &ct); err != nil { return ConstructTemplate{}, fmt.Errorf("failed to unmarshal yaml: %w", err) } - cachedConstructs[template.Id] = template + cachedConstructs[ct.Id] = ct - return template, nil + return ct, nil } -func loadBindingTemplate(owner ConstructTemplateId, from ConstructTemplateId, to ConstructTemplateId) (BindingTemplate, error) { +func LoadBindingTemplate(owner property.ConstructType, from property.ConstructType, to property.ConstructType) (BindingTemplate, error) { mu.Lock() defer mu.Unlock() if owner != from && owner != to { @@ -72,8 +73,8 @@ func loadBindingTemplate(owner ConstructTemplateId, from ConstructTemplateId, to cacheKey := fmt.Sprintf("%s/%s", owner.String(), bindingKey) - if template, ok := cachedBindings[cacheKey]; ok { - return template, nil + if ct, ok := cachedBindings[cacheKey]; ok { + return ct, nil } constructDir, err := getConstructTemplateDir(owner) @@ -89,18 +90,18 @@ func loadBindingTemplate(owner ConstructTemplateId, from ConstructTemplateId, to } // Unmarshal the YAML fileContent into a map - var template BindingTemplate - if err := yaml.Unmarshal(fileContent, &template); err != nil { + var ct BindingTemplate + if err := yaml.Unmarshal(fileContent, &ct); err != nil { return BindingTemplate{}, fmt.Errorf("failed to unmarshal yaml: %w", err) } // Cache the binding template for future use - cachedBindings[cacheKey] = template + cachedBindings[cacheKey] = ct - return template, nil + return ct, nil } -func getConstructTemplateDir(id ConstructTemplateId) (string, error) { +func getConstructTemplateDir(id property.ConstructType) (string, error) { // trim the klotho package prefix parts := strings.SplitN(id.Package, ".", 2) if len(parts) < 2 { diff --git a/pkg/k2/constructs/templates/aws/api/api.yaml b/pkg/k2/constructs/template/templates/aws/api/api.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/api/api.yaml rename to pkg/k2/constructs/template/templates/aws/api/api.yaml diff --git a/pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Container.yaml b/pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Container.yaml new file mode 100644 index 000000000..d3e57d096 --- /dev/null +++ b/pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Container.yaml @@ -0,0 +1,72 @@ +from: klotho.aws.API +to: klotho.aws.Container + +inputs: + Routes: + name: Routes + description: The routes to use + type: list(map) + required: true + min_length: 1 + properties: + Path: + name: Path + description: The path to use + type: string + default: /* + Proxy: + name: Proxy + description: Add a proxy route for child resources + type: bool + default: false + Method: + name: Method + description: The method to use + type: string + default: ANY + +input_rules: + - for_each: '{{ .Select "inputs.Routes" }}' + prefix: '{{ toLower (replace `[^\w]` "-" .Selected.Value.Path) }}' + do: + rules: + - if: '{{ .Selected.Value.Proxy }}' + then: + resources: + ProxyMethod: + type: aws:api_method + namespace: ${from.resources:RestAPI.Name} + name: '{{ .Prefix }}-proxy-method' + properties: + HttpMethod: '{{ toUpper .Selected.Value.Method }}' + ProxyIntegration: + type: aws:api_integration + namespace: ${from.inputs:Name} + name: '{{ .Prefix }}-proxy-integration' + properties: + Route: '{{ trimSuffix .Selected.Value.Path "/" }}/{proxy+}' + Method: '${resources:[{{.Prefix}}.ProxyMethod]}' + edges: + - from: ${from.resources:RestAPI} + to: '{{ .Prefix }}.ProxyIntegration' + - from: '{{ .Prefix }}.ProxyIntegration' + to: ${to.resources:APILoadBalancer} + resources: + Method: + type: aws:api_method + namespace: ${from.resources:RestAPI.Name} + name: '{{ .Prefix }}-{{ toLower (replace `[^\w]` "-" .Selected.Value.Method) }}-method' + properties: + HttpMethod: '{{ toUpper .Selected.Value.Method }}' + Integration: + type: aws:api_integration + namespace: ${from.inputs:Name} + name: '{{ .Prefix }}-{{ toLower (replace `[^\w]` "-" .Selected.Value.Method) }}-integration' + properties: + Route: '{{ .Selected.Value.Path }}' + Method: '${resources:[{{.Prefix}}.Method]}' + edges: + - from: ${from.resources:RestAPI} + to: '{{ .Prefix }}.Integration' + - from: '{{ .Prefix }}.Integration' + to: ${to.resources:APILoadBalancer} \ No newline at end of file diff --git a/pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Function.yaml b/pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Function.yaml similarity index 95% rename from pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Function.yaml rename to pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Function.yaml index 074d26922..c8a1ea4ce 100644 --- a/pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Function.yaml +++ b/pkg/k2/constructs/template/templates/aws/api/bindings/to_klotho.aws.Function.yaml @@ -6,8 +6,7 @@ inputs: name: Path description: The path to use type: string - default: /* - + default_value: / resources: Integration: type: aws:api_integration diff --git a/pkg/k2/constructs/templates/aws/bucket/bucket.yaml b/pkg/k2/constructs/template/templates/aws/bucket/bucket.yaml similarity index 79% rename from pkg/k2/constructs/templates/aws/bucket/bucket.yaml rename to pkg/k2/constructs/template/templates/aws/bucket/bucket.yaml index 2f8fbe5d0..e091b1e78 100644 --- a/pkg/k2/constructs/templates/aws/bucket/bucket.yaml +++ b/pkg/k2/constructs/template/templates/aws/bucket/bucket.yaml @@ -7,14 +7,13 @@ resources: name: ${inputs:Name} properties: ForceDestroy: ${inputs:ForceDestroy} - IndexDocument: ${inputs:IndexDocument} SSEAlgorithm: ${inputs:SSEAlgorithm} inputs: ForceDestroy: name: ForceDestroy description: Whether to forcibly delete the S3 bucket and all objects it contains during destruction type: bool - default: true + default_value: true IndexDocument: name: IndexDocument description: The webpage that Amazon S3 returns when it receives a request to the root domain name of the bucket or when an index document is specified @@ -23,7 +22,7 @@ inputs: name: SSEAlgorithm description: The server-side encryption algorithm to use to encrypt data stored in the S3 bucket type: string - default: aws:kms + default_value: aws:kms outputs: Bucket: description: The name of the S3 bucket @@ -33,4 +32,12 @@ outputs: value: ${resources:Bucket#Arn} BucketRegionalDomainName: description: The regional domain name of the S3 bucket - value: ${resources:Bucket#BucketRegionalDomainName} \ No newline at end of file + value: ${resources:Bucket#BucketRegionalDomainName} + +input_rules: + - if: '{{ .Inputs.IndexDocument }}' + then: + resources: + Bucket: + properties: + IndexDocument: ${inputs:IndexDocument} diff --git a/pkg/k2/constructs/templates/aws/container/bindings/from_klotho.aws.Api.yaml b/pkg/k2/constructs/template/templates/aws/container/bindings/from_klotho.aws.Api.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/container/bindings/from_klotho.aws.Api.yaml rename to pkg/k2/constructs/template/templates/aws/container/bindings/from_klotho.aws.Api.yaml diff --git a/pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml b/pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml similarity index 90% rename from pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml rename to pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml index 485fd0644..b54e8299a 100644 --- a/pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml +++ b/pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.Bucket.yaml @@ -5,8 +5,8 @@ inputs: ReadOnly: name: Read Only description: Whether the connection should be read only - type: boolean - default: false + type: bool + default_value: false resources: TaskDefinition: properties: @@ -15,7 +15,7 @@ resources: - Name: ${to.inputs:Name}_BUCKET_ENDPOINT Value: ${to.outputs:BucketRegionalDomainName} input_rules: - - if: '{{ inputs "ReadOnly" }}' + - if: '{{ .Inputs.ReadOnly }}' then: edges: - from: ${from.resources:Service} diff --git a/pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.DynamoDB.yaml b/pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.DynamoDB.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.DynamoDB.yaml rename to pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.DynamoDB.yaml diff --git a/pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml b/pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml similarity index 84% rename from pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml rename to pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml index 2c41944a9..cc460baf1 100644 --- a/pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml +++ b/pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.LoadBalancer.yaml @@ -4,11 +4,10 @@ inputs: Port: name: Port description: The port to expose on the load balancer - type: number + type: int default: ${from.resources:TaskDefinition.ContainerDefinitions[0].PortMappings[0].HostPort} - validation: - minimum: 1 - maximum: 65535 + minimum: 1 + maximum: 65535 HealthCheck: name: Health Check description: The health check to use for the load balancer @@ -23,25 +22,22 @@ inputs: description: The protocol to use for the health check type: string default: HTTP - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 Path: name: Path description: The path to use for the health check type: string default: / - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 Matcher: name: Matcher description: The matcher to use for the health check type: string default: 200-399 - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 resources: TargetGroup: type: aws:target_group diff --git a/pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.Postgres.yaml b/pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.Postgres.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/container/bindings/to_klotho.aws.Postgres.yaml rename to pkg/k2/constructs/template/templates/aws/container/bindings/to_klotho.aws.Postgres.yaml diff --git a/pkg/k2/constructs/templates/aws/container/container.yaml b/pkg/k2/constructs/template/templates/aws/container/container.yaml similarity index 73% rename from pkg/k2/constructs/templates/aws/container/container.yaml rename to pkg/k2/constructs/template/templates/aws/container/container.yaml index 9ae06f0fd..c7165c874 100644 --- a/pkg/k2/constructs/templates/aws/container/container.yaml +++ b/pkg/k2/constructs/template/templates/aws/container/container.yaml @@ -31,68 +31,69 @@ inputs: Cpu: name: CPU description: The amount of CPU to allocate to the container - type: number - default: 256 - validation: - minimum: 1 - maximum: 4096 + type: int + default_value: 256 + minimum: 1 + maximum: 4096 Context: name: Context description: The context to use to build the container type: path - default: . - validation: - minLength: 1 - maxLength: 63 + default_value: . + min_length: 1 + max_length: 63 Dockerfile: name: Dockerfile description: The Dockerfile to use to build the container type: path - default: Dockerfile - validation: - minLength: 1 - maxLength: 63 + default_value: Dockerfile + min_length: 1 + max_length: 63 EnvironmentVariables: name: EnvironmentVariables description: The environment variables to set in the container - type: KeyValueList - configuration: - keyField: Name + type: key_value_list(string,string) + key_property: + name: Name + type: string + min_length: 1 Image: name: Image description: The image to use for the container type: string - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 EnableExecuteCommand: name: Enable Execute Command description: Whether to enable the execute command functionality for the container - type: boolean - default: false + type: bool + default_value: false Memory: name: Memory description: The amount of memory to allocate to the container - type: number - default: 512 - validation: - minimum: 1 - maximum: 4096 + type: int + default_value: 512 + minimum: 1 + maximum: 4096 Network: name: Network description: The network to deploy the container to - type: Construct(klotho.aws.Network) + type: construct(klotho.aws.Network) Port: name: Port description: The port to expose on the container - type: number - default: 80 - validation: - minimum: 1 - maximum: 65535 + type: int + default_value: 80 + minimum: 1 + maximum: 65535 + HealthCheck: + name: HealthCheck + description: The health check to use for the container + type: map(string, string) + default_value: "CMD-SHELL curl -f http://localhost:${inputs:Port}/ || exit 1" input_rules: - - if: '{{ and (inputs "Dockerfile") (not (inputs "Image")) }}' + - if: '{{ and (.Inputs.Dockerfile) (not (.Inputs.Image)) }}' then: resources: EcrImage: diff --git a/pkg/k2/constructs/templates/aws/dynamodb/dynamodb.yaml b/pkg/k2/constructs/template/templates/aws/dynamodb/dynamodb.yaml similarity index 77% rename from pkg/k2/constructs/templates/aws/dynamodb/dynamodb.yaml rename to pkg/k2/constructs/template/templates/aws/dynamodb/dynamodb.yaml index 4f26f5598..9f99691fc 100644 --- a/pkg/k2/constructs/templates/aws/dynamodb/dynamodb.yaml +++ b/pkg/k2/constructs/template/templates/aws/dynamodb/dynamodb.yaml @@ -11,45 +11,34 @@ resources: HashKey: ${inputs:HashKey} inputs: - Name: - name: Table Name - description: The name of the DynamoDB table - type: string - default: my_table - validation: - minLength: 3 - maxLength: 255 - Attributes: name: Attributes description: List of attribute definitions for the table which includes attribute name and type - type: list - default: + type: list(map) + default_value: - Name: id Type: S properties: Name: type: string description: Name of the attribute - validation: - minLength: 1 - maxLength: 255 + min_length: 1 + max_length: 255 Type: type: string description: The data type for the attribute, such as String (S) or Number (N) - allowedValues: + allowed_values: - S - N - B - validation: - minItems: 1 + min_items: 1 BillingMode: name: Billing Mode description: The billing mode that determines how you are charged for read and write throughput and how you manage capacity type: string - default: PAY_PER_REQUEST - allowedValues: + default_value: PAY_PER_REQUEST + allowed_values: - PROVISIONED - PAY_PER_REQUEST @@ -57,18 +46,16 @@ inputs: name: Hash Key description: The table hash key, which is the partition key for the DynamoDB table type: string - default: id - validation: - minLength: 1 - maxLength: 255 + default_value: id + min_length: 1 + max_length: 255 RangeKey: name: Range Key description: The table range key, which is the sort key for the DynamoDB table type: string - validation: - minLength: 1 - maxLength: 255 + min_length: 1 + max_length: 255 GlobalSecondaryIndexes: name: Global Secondary Indexes @@ -97,27 +84,27 @@ outputs: value: ${resources:DynamoDBTable#Arn} input_rules: - - if: '{{ (inputs "RangeKey") }}' + - if: '{{ .Inputs.RangeKey }}' then: resources: DynamoDBTable: properties: RangeKey: ${inputs:RangeKey} - - if: '{{ not (eq (inputs "Tags") nil) }}' + - if: '{{ not (eq .Inputs.Tags nil) }}' then: resources: DynamoDBTable: properties: Tags: ${inputs:Tags} - - if: '{{ not (eq (inputs "GlobalSecondaryIndexes") nil) }}' + - if: '{{ not (eq .Inputs.GlobalSecondaryIndexes nil) }}' then: resources: DynamoDBTable: properties: GlobalSecondaryIndexes: ${inputs:GlobalSecondaryIndexes} - - if: '{{ not (eq (inputs "LocalSecondaryIndexes") nil) }}' + - if: '{{ not (eq .Inputs.LocalSecondaryIndexes nil) }}' then: resources: DynamoDBTable: diff --git a/pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml b/pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml similarity index 93% rename from pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml rename to pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml index e8cc3d0a5..b52ea3118 100644 --- a/pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml +++ b/pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.Bucket.yaml @@ -5,7 +5,7 @@ inputs: ReadOnly: name: Read Only description: Whether the connection should be read only - type: boolean + type: bool default: false resources: TaskDefinition: @@ -15,7 +15,7 @@ resources: - Name: ${to.inputs:Name}_BUCKET_ENDPOINT Value: ${to.outputs:BucketRegionalDomainName} input_rules: - - if: '{{ inputs "ReadOnly" }}' + - if: '{{ .Inputs.ReadOnly }}' then: edges: - from: ${from.resources:Service} diff --git a/pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.DynamoDB.yaml b/pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.DynamoDB.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.DynamoDB.yaml rename to pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.DynamoDB.yaml diff --git a/pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.Postgres.yaml b/pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.Postgres.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/fastapi/bindings/to_klotho.aws.Postgres.yaml rename to pkg/k2/constructs/template/templates/aws/fastapi/bindings/to_klotho.aws.Postgres.yaml diff --git a/pkg/k2/constructs/templates/aws/fastapi/fastapi.yaml b/pkg/k2/constructs/template/templates/aws/fastapi/fastapi.yaml similarity index 86% rename from pkg/k2/constructs/templates/aws/fastapi/fastapi.yaml rename to pkg/k2/constructs/template/templates/aws/fastapi/fastapi.yaml index f850b8891..af288886d 100644 --- a/pkg/k2/constructs/templates/aws/fastapi/fastapi.yaml +++ b/pkg/k2/constructs/template/templates/aws/fastapi/fastapi.yaml @@ -63,27 +63,24 @@ inputs: Cpu: name: CPU description: The amount of CPU to allocate to the container - type: number + type: int default: 256 - validation: - minimum: 1 - maximum: 4096 + minimum: 1 + maximum: 4096 Context: name: Context description: The context to use to build the container type: path default: . - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 Dockerfile: name: Dockerfile description: The Dockerfile to use to build the container type: path default: Dockerfile - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 EnvironmentVariables: name: EnvironmentVariables description: The environment variables to set in the container @@ -94,22 +91,20 @@ inputs: name: Image description: The image to use for the container type: string - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 EnableExecuteCommand: name: Enable Execute Command description: Whether to enable the execute command functionality for the container - type: boolean + type: bool default: false Memory: name: Memory description: The amount of memory to allocate to the container - type: number + type: int default: 512 - validation: - minimum: 1 - maximum: 4096 + minimum: 1 + maximum: 4096 Network: name: Network description: The network to deploy the container to @@ -117,11 +112,10 @@ inputs: Port: name: Port description: The port to expose on the container - type: number + type: int default: 80 - validation: - minimum: 1 - maximum: 65535 + minimum: 1 + maximum: 65535 HealthCheckPath: name: Health Check Path description: The path to use for the health check @@ -135,16 +129,16 @@ inputs: HealthCheckHealthyThreshold: name: Health Check Healthy Threshold description: The number of consecutive successful health checks required before considering the target healthy - type: number + type: int default: 3 HealthCheckUnhealthyThreshold: name: Health Check Unhealthy Threshold description: The number of consecutive failed health checks required before considering the target unhealthy - type: number + type: int default: 3 input_rules: - - if: '{{ and (inputs "Dockerfile") (not (inputs "Image")) }}' + - if: '{{ and .Inputs.Dockerfile (not .Inputs.Image) }}' then: resources: EcrImage: @@ -162,35 +156,35 @@ input_rules: properties: ContainerDefinitions[0].Image: ${inputs:Image} - - if: '{{or (inputs "HealthCheckPath") (inputs "HealthCheckMatcher")}}' + - if: '{{ or .Inputs.HealthCheckPath .Inputs.HealthCheckMatcher }}' then: resources: TargetGroup: properties: HealthCheck.Protocol: HTTP - - if: '{{ (inputs "HealthCheckPath")}}' + - if: '{{ .Inputs.HealthCheckPath }}' then: resources: TargetGroup: properties: HealthCheck.Path: ${inputs:HealthCheckPath} - - if: '{{ (inputs "HealthCheckMatcher")}}' + - if: '{{ .Inputs.HealthCheckMatcher }}' then: resources: TargetGroup: properties: HealthCheck.Matcher: ${inputs:HealthCheckMatcher} - - if: '{{ (inputs "HealthCheckHealthyThreshold")}}' + - if: '{{ .Inputs.HealthCheckHealthyThreshold }}' then: resources: TargetGroup: properties: HealthCheck.HealthyThreshold: ${inputs:HealthCheckHealthyThreshold} - - if: '{{ (inputs "HealthCheckUnhealthyThreshold")}}' + - if: '{{ .Inputs.HealthCheckUnhealthyThreshold }}' then: resources: TargetGroup: diff --git a/pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml b/pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml new file mode 100644 index 000000000..d42836b97 --- /dev/null +++ b/pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml @@ -0,0 +1,15 @@ +from: klotho.aws.Function +to: klotho.aws.Bucket + +inputs: + ReadOnly: + name: Read Only + description: Whether the connection should be read only + type: bool + default_value: false + +edges: + - from: ${from.resources:LambdaFunction} + to: ${to.resources:Bucket} + data: + connection_type: "{{ if .Inputs.ReadOnly }}readonly{{ end }}" diff --git a/pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.DynamoDB.yaml b/pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.DynamoDB.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.DynamoDB.yaml rename to pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.DynamoDB.yaml diff --git a/pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.Postgres.yaml b/pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.Postgres.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.Postgres.yaml rename to pkg/k2/constructs/template/templates/aws/function/bindings/to_klotho.aws.Postgres.yaml diff --git a/pkg/k2/constructs/templates/aws/function/function.yaml b/pkg/k2/constructs/template/templates/aws/function/function.yaml similarity index 50% rename from pkg/k2/constructs/templates/aws/function/function.yaml rename to pkg/k2/constructs/template/templates/aws/function/function.yaml index 289beb75e..0d9b47b8b 100644 --- a/pkg/k2/constructs/templates/aws/function/function.yaml +++ b/pkg/k2/constructs/template/templates/aws/function/function.yaml @@ -7,8 +7,6 @@ resources: type: aws:lambda_function name: ${inputs:Name}-function properties: - Handler: ${inputs:Handler} - Runtime: ${inputs:Runtime} Timeout: ${inputs:Timeout} MemorySize: ${inputs:MemorySize} EnvironmentVariables: ${inputs:EnvironmentVariables} @@ -18,55 +16,51 @@ inputs: name: Handler description: The function entrypoint in your code (not applicable for container images) type: string - validation: - minLength: 1 - maxLength: 128 + min_length: 1 + max_length: 128 Runtime: name: Runtime description: The runtime environment for the Lambda function (not applicable for container images) type: string - default: nodejs14.x - validation: - allowedValues: - - nodejs20.x - - nodejs18.x - - nodejs16.x - - python3.12 - - python3.11 - - python3.10 - - python3.9 - - python3.8 - - java21 - - java17 - - java11 - - java8.al2 - - dotnet8 - - dotnet6 - - ruby3.3 - - ruby3.2 - - provided.al2023 - - provided.al2 + default_value: nodejs14.x + allowed_values: + - nodejs20.x + - nodejs18.x + - nodejs16.x + - python3.12 + - python3.11 + - python3.10 + - python3.9 + - python3.8 + - java21 + - java17 + - java11 + - java8.al2 + - dotnet8 + - dotnet6 + - ruby3.3 + - ruby3.2 + - provided.al2023 + - provided.al2 Timeout: name: Timeout description: The amount of time that Lambda allows a function to run before stopping it - type: number - default: 3 - validation: - minimum: 1 - maximum: 900 + type: int + default_value: 3 + minimum: 1 + maximum: 900 MemorySize: name: Memory Size description: The amount of memory available to the function at runtime - type: number - default: 128 - validation: - minimum: 128 - maximum: 10240 + type: int + default_value: 128 + minimum: 128 + maximum: 10240 EnvironmentVariables: name: Environment Variables description: Environment variables that are accessible from function code during execution type: map - default: {} + default_value: {} Code: name: Code description: The source code of your Lambda function (local path) @@ -97,64 +91,67 @@ inputs: type: path input_rules: - - if: '{{ inputs "Code" }}' + - if: '{{ .Inputs.Code }}' then: resources: LambdaFunction: properties: Code: ${inputs:Code} - - if: '{{ and (inputs "S3Bucket") (inputs "S3Key") }}' - then: - resources: - LambdaFunction: - properties: - S3Bucket: ${inputs:S3Bucket} - S3Key: ${inputs:S3Key} - - if: '{{ inputs "S3ObjectVersion" }}' - then: - resources: - LambdaFunction: - properties: - S3ObjectVersion: ${inputs:S3ObjectVersion} - - if: '{{ inputs "ImageUri" }}' - then: - resources: - LambdaFunction: - properties: - ImageUri: ${inputs:ImageUri} - - if: '{{ inputs "Dockerfile" }}' - then: - resources: - Image: - type: aws:ecr_image - name: ${inputs:Name}-image - properties: - Dockerfile: ${inputs:Dockerfile} - LambdaFunction: - properties: - Image: aws:ecr_image:${inputs:Name}-image#ImageName - - if: '{{ inputs "DockerContext" }}' - then: - resources: - Image: - type: aws:ecr_image - name: ${inputs:Name}-image - properties: - Context: ${inputs:DockerContext} - LambdaFunction: - properties: - Image: aws:ecr_image:${inputs:Name}-image#ImageName - - if: '{{ or (inputs "ImageUri") (inputs "Dockerfile") }}' - then: - resources: - LambdaFunction: - properties: - PackageType: Image - else: - resources: - LambdaFunction: - properties: + Handler: ${inputs:Handler} + Runtime: ${inputs:Runtime} PackageType: Zip + rules: + - if: '{{ and .Inputs.S3Bucket .Inputs.S3Key }}' + then: + resources: + LambdaFunction: + properties: + S3Bucket: ${inputs:S3Bucket} + S3Key: ${inputs:S3Key} + rules: + - if: '{{ .Inputs.S3ObjectVersion }}' + then: + resources: + LambdaFunction: + properties: + S3ObjectVersion: ${inputs:S3ObjectVersion} + else: + rules: + - if: '{{ .Inputs.ImageUri }}' + then: + resources: + LambdaFunction: + properties: + ImageUri: ${inputs:ImageUri} + - if: '{{ .Inputs.Dockerfile }}' + then: + resources: + Image: + type: aws:ecr_image + name: ${inputs:Name}-image + properties: + Dockerfile: ${inputs:Dockerfile} + LambdaFunction: + properties: + Image: aws:ecr_image:${inputs:Name}-image#ImageName + - if: '{{ .Inputs.DockerContext }}' + then: + resources: + Image: + type: aws:ecr_image + name: ${inputs:Name}-image + properties: + Context: ${inputs:DockerContext} + LambdaFunction: + properties: + Image: aws:ecr_image:${inputs:Name}-image#ImageName + - if: '{{ or .Inputs.ImageUri .Inputs.Dockerfile }}' + then: + resources: + LambdaFunction: + properties: + PackageType: Image + outputs: FunctionArn: diff --git a/pkg/k2/constructs/templates/aws/network/network.yaml b/pkg/k2/constructs/template/templates/aws/network/network.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/network/network.yaml rename to pkg/k2/constructs/template/templates/aws/network/network.yaml diff --git a/pkg/k2/constructs/templates/aws/postgres/bindings/from_klotho.aws.Container.yaml b/pkg/k2/constructs/template/templates/aws/postgres/bindings/from_klotho.aws.Container.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/postgres/bindings/from_klotho.aws.Container.yaml rename to pkg/k2/constructs/template/templates/aws/postgres/bindings/from_klotho.aws.Container.yaml diff --git a/pkg/k2/constructs/templates/aws/postgres/bindings/from_klotho.aws.FastAPI.yaml b/pkg/k2/constructs/template/templates/aws/postgres/bindings/from_klotho.aws.FastAPI.yaml similarity index 100% rename from pkg/k2/constructs/templates/aws/postgres/bindings/from_klotho.aws.FastAPI.yaml rename to pkg/k2/constructs/template/templates/aws/postgres/bindings/from_klotho.aws.FastAPI.yaml diff --git a/pkg/k2/constructs/templates/aws/postgres/postgres.yaml b/pkg/k2/constructs/template/templates/aws/postgres/postgres.yaml similarity index 85% rename from pkg/k2/constructs/templates/aws/postgres/postgres.yaml rename to pkg/k2/constructs/template/templates/aws/postgres/postgres.yaml index 5f6b6e75b..9c048559d 100644 --- a/pkg/k2/constructs/templates/aws/postgres/postgres.yaml +++ b/pkg/k2/constructs/template/templates/aws/postgres/postgres.yaml @@ -26,62 +26,55 @@ inputs: description: The instance class for the database instance type: string default: db.t3.micro - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 AllocatedStorage: name: Allocated Storage description: The amount of storage to allocate to the database instance (in GB) - type: number + type: int default: 20 - validation: - minValue: 5 - maxValue: 6144 + min_value: 5 + max_value: 6144 EngineVersion: name: Engine Version description: The version of the Postgres engine to use type: string default: '14.11' - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 Username: name: Master Username description: The master username for the database instance type: string default: admin - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 Password: name: Master User Password description: The master user password for the database instance type: string - validation: - minLength: 8 - maxLength: 128 + min_length: 8 + max_length: 128 DatabaseName: name: Database Name description: The name of the database type: string default: main - validation: - minLength: 1 - maxLength: 63 + min_length: 1 + max_length: 63 Port: name: Port description: The port to expose on the database instance - type: number + type: int default: 5432 - validation: - minValue: 1 - maxValue: 65535 + min_value: 1 + max_value: 65535 Network: name: Network diff --git a/pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Container.yaml b/pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Container.yaml deleted file mode 100644 index 388d7cfab..000000000 --- a/pkg/k2/constructs/templates/aws/api/bindings/to_klotho.aws.Container.yaml +++ /dev/null @@ -1,23 +0,0 @@ -from: klotho.aws.API -to: klotho.aws.Container - -inputs: - Path: - name: Path - description: The path to use - type: string - default: /* - -resources: - Integration: - type: aws:api_integration - namespace: ${from.inputs:Name} - name: ${to.inputs:Name} - properties: - Route: ${inputs:Path} - -edges: - - from: ${from.resources:RestAPI} - to: ${resources:Integration} - - from: ${resources:Integration} - to: ${to.resources:APILoadBalancer} diff --git a/pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml b/pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml deleted file mode 100644 index 7e2863dd1..000000000 --- a/pkg/k2/constructs/templates/aws/function/bindings/to_klotho.aws.Bucket.yaml +++ /dev/null @@ -1,21 +0,0 @@ -from: klotho.aws.Function -to: klotho.aws.Bucket - -inputs: - ReadOnly: - name: Read Only - description: Whether the connection should be read only - type: boolean - default: false -input_rules: - - if: '{{ inputs "ReadOnly" }}' - then: - edges: - - from: ${from.resources:LambdaFunction} - to: ${to.resources:Bucket} - data: - connection_type: readonly - else: - edges: - - from: ${from.resources:LambdaFunction} - to: ${to.resources:Bucket} diff --git a/pkg/k2/ir_samples/container.yaml b/pkg/k2/ir_samples/container.yaml index 9d2a69c7b..73e953c91 100644 --- a/pkg/k2/ir_samples/container.yaml +++ b/pkg/k2/ir_samples/container.yaml @@ -16,10 +16,10 @@ constructs: value: Dockerfile status: resolved Cpu: - type: number + type: int value: 256 Memory: - type: number + type: int value: 512 source_hash: type: string diff --git a/pkg/k2/ir_samples/testenv.yaml b/pkg/k2/ir_samples/testenv.yaml index 198618ebb..da7ed2791 100644 --- a/pkg/k2/ir_samples/testenv.yaml +++ b/pkg/k2/ir_samples/testenv.yaml @@ -18,6 +18,6 @@ constructs: value: nginx:latest status: resolved port: - type: number + type: int value: 80 status: resolved diff --git a/pkg/k2/k2_test.go b/pkg/k2/k2_test.go index fae5ca34d..5f1da79c7 100644 --- a/pkg/k2/k2_test.go +++ b/pkg/k2/k2_test.go @@ -230,14 +230,17 @@ func (tc testCase) assertConstructFileEquals(t *testing.T, construct, file strin } func assertYamlEquals(t *testing.T, file string, expectedF, actualF io.Reader) bool { + var expectedB, actualB []byte + expectedF.Read(expectedB) + actualF.Read(actualB) var expect, actual map[string]interface{} - err := yaml.NewDecoder(expectedF).Decode(&expect) + err := yaml.Unmarshal(expectedB, &expect) if err != nil { t.Errorf("failed to read expected yaml %s: %v", file, err) return false } - err = yaml.NewDecoder(actualF).Decode(&actual) + err = yaml.Unmarshal(actualB, &actual) if err != nil { t.Errorf("failed to read actual yaml %s: %v", file, err) return false @@ -249,8 +252,11 @@ func assertYamlEquals(t *testing.T, file string, expectedF, actualF io.Reader) b } changes, err := differ.Diff(expect, actual) if err != nil { - t.Errorf("failed to diff %s: %v", file, err) - return false + changes, err = differ.Diff(string(expectedB), string(actualB)) + if err != nil { + t.Errorf("failed to diff %s: %v", file, err) + return false + } } for _, c := range changes { path := strings.Join(c.Path, ".") diff --git a/pkg/k2/language_host/language_host.go b/pkg/k2/language_host/language_host.go index e540887ee..b533fc104 100644 --- a/pkg/k2/language_host/language_host.go +++ b/pkg/k2/language_host/language_host.go @@ -79,7 +79,6 @@ func copyToTempDir(name, content string) (string, error) { return "", fmt.Errorf("failed to write to temp file: %w", err) } return f.Name(), nil - } func StartPythonClient(ctx context.Context, debugConfig DebugConfig, pythonPath string) (*exec.Cmd, *ServerState, error) { diff --git a/pkg/k2/language_host/python/klothosdk/src/klotho/aws/api.py b/pkg/k2/language_host/python/klothosdk/src/klotho/aws/api.py index 101f0535b..ccad0c42a 100644 --- a/pkg/k2/language_host/python/klothosdk/src/klotho/aws/api.py +++ b/pkg/k2/language_host/python/klothosdk/src/klotho/aws/api.py @@ -3,11 +3,34 @@ from klotho.construct import Binding, Construct, ConstructOptions, add_binding +class RouteArgs: + def __init__(self, path: str, method: str = "ANY", proxy: bool = False): + self.path = path + self.method = method + self.proxy = proxy + + class Api(Construct): def __init__(self, name: str, opts: Optional[ConstructOptions] = None): super().__init__( name, construct_type="klotho.aws.Api", properties={}, opts=opts ) - def route_to(self, path: str, dest: Construct): - add_binding(self, Binding(dest, {"Path": path})) + def route(self, routes: list[RouteArgs], destination: Construct): + + add_binding( + self, + Binding( + destination, + { + "Routes": [ + { + "Path": route.path, + "Method": route.method, + "Proxy": route.proxy, + } + for route in routes + ] + }, + ), + ) diff --git a/pkg/k2/language_host/python/klothosdk/src/klotho/construct.py b/pkg/k2/language_host/python/klothosdk/src/klotho/construct.py index 53b6f18b2..5ecfd58cd 100644 --- a/pkg/k2/language_host/python/klothosdk/src/klotho/construct.py +++ b/pkg/k2/language_host/python/klothosdk/src/klotho/construct.py @@ -195,3 +195,9 @@ def add_binding(source: Construct, binding: BindingType): except ValueError: # Ignore non-URN dependencies. pass + +def get_binding(source: Construct, to: URN) -> Optional[Binding]: + for binding in source.bindings: + if binding.to == to: + return binding + return None \ No newline at end of file diff --git a/pkg/k2/language_host/python/klothosdk/src/klotho/output.py b/pkg/k2/language_host/python/klothosdk/src/klotho/output.py index 3c52912d8..b6b16708f 100644 --- a/pkg/k2/language_host/python/klothosdk/src/klotho/output.py +++ b/pkg/k2/language_host/python/klothosdk/src/klotho/output.py @@ -155,7 +155,7 @@ def run(*values: str) -> str: return cls.all(inputs, run) @staticmethod - def from_mapping(input: Input[Mapping]) -> "Output[Mapping]": + def from_mapping(input: Input[Mapping]) -> "Input[Mapping]": if isinstance(input, Output): return input if isinstance(input, Mapping): @@ -169,6 +169,9 @@ def from_mapping(input: Input[Mapping]) -> "Output[Mapping]": else: resolved_mappings[key] = value + if not unresolved_mappings: + return input + def callback(resolved_outputs: Mapping) -> Mapping: result = {**resolved_mappings, **resolved_outputs} return result diff --git a/pkg/k2/language_host/python/samples/starter/infra-api.py b/pkg/k2/language_host/python/samples/starter/infra-api.py index f010c52c9..a41357e46 100644 --- a/pkg/k2/language_host/python/samples/starter/infra-api.py +++ b/pkg/k2/language_host/python/samples/starter/infra-api.py @@ -3,6 +3,7 @@ import klotho import klotho.aws as aws +from klotho.aws.api import RouteArgs app = klotho.Application( "api", @@ -18,4 +19,11 @@ ) api = aws.Api("my-api") -api.route_to("/", container) +my_api = aws.Api("my-api") +my_api.route( + [ + RouteArgs(path="/", method="GET", proxy=True), + ], + container, +) + diff --git a/pkg/k2/model/urn.go b/pkg/k2/model/urn.go index 5361b9bdd..094da7ffe 100644 --- a/pkg/k2/model/urn.go +++ b/pkg/k2/model/urn.go @@ -255,3 +255,7 @@ func (u *URN) Compare(other URN) int { } return 0 } + +func (u *URN) IsZero() bool { + return u == nil || *u == URN{} +} diff --git a/pkg/k2/orchestration/up_orchestrator.go b/pkg/k2/orchestration/up_orchestrator.go index 1de6e1d2d..a30028efb 100644 --- a/pkg/k2/orchestration/up_orchestrator.go +++ b/pkg/k2/orchestration/up_orchestrator.go @@ -223,6 +223,10 @@ func (uo *UpOrchestrator) executeAction(ctx context.Context, c model.ConstructSt case model.DryRunPreview: _, err = stack.RunPreview(ctx, uo.FS, stackRef) uo.placeholderOutputs(ctx, *c.URN) + if err != nil { + return fmt.Errorf("error running pulumi preview command: %w", err) + } + err = sm.RegisterOutputValues(ctx, stackRef.ConstructURN, map[string]any{}) return err case model.DryRunCompile: @@ -245,12 +249,11 @@ func (uo *UpOrchestrator) executeAction(ctx context.Context, c model.ConstructSt if err != nil { return fmt.Errorf("error running tsc: %w", err) } - return nil - + return sm.RegisterOutputValues(ctx, stackRef.ConstructURN, map[string]any{}) case model.DryRunFileOnly: // file already written, nothing left to do uo.placeholderOutputs(ctx, *c.URN) - return nil + return sm.RegisterOutputValues(ctx, stackRef.ConstructURN, map[string]any{}) } // Run pulumi up command for the construct diff --git a/pkg/k2/testdata/bucket_ro/my-bucket.engine_input.yaml b/pkg/k2/testdata/bucket_ro/my-bucket.engine_input.yaml index c6fb00ea1..2595dda6a 100644 --- a/pkg/k2/testdata/bucket_ro/my-bucket.engine_input.yaml +++ b/pkg/k2/testdata/bucket_ro/my-bucket.engine_input.yaml @@ -7,11 +7,6 @@ constraints: target: aws:s3_bucket:my-bucket property: ForceDestroy value: true - - scope: resource - operator: equals - target: aws:s3_bucket:my-bucket - property: IndexDocument - value: null - scope: resource operator: equals target: aws:s3_bucket:my-bucket diff --git a/pkg/k2/testdata/bucket_ro/my-bucket.resources.yaml b/pkg/k2/testdata/bucket_ro/my-bucket.resources.yaml index 4e9a4d2c2..9b2ef85cf 100644 --- a/pkg/k2/testdata/bucket_ro/my-bucket.resources.yaml +++ b/pkg/k2/testdata/bucket_ro/my-bucket.resources.yaml @@ -1,7 +1,6 @@ resources: aws:s3_bucket:my-bucket: ForceDestroy: true - IndexDocument: "" SSEAlgorithm: aws:kms Tags: GLOBAL_KLOTHO_TAG: k2 diff --git a/pkg/k2/testdata/bucket_ro/my-container.engine_input.yaml b/pkg/k2/testdata/bucket_ro/my-container.engine_input.yaml index bfba0a803..a9d65bc4c 100644 --- a/pkg/k2/testdata/bucket_ro/my-container.engine_input.yaml +++ b/pkg/k2/testdata/bucket_ro/my-container.engine_input.yaml @@ -85,7 +85,6 @@ resources: aws:s3_bucket:my-bucket: ForceDestroy: true Id: preview(id=aws:s3_bucket:my-bucket) - IndexDocument: "" SSEAlgorithm: aws:kms Tags: GLOBAL_KLOTHO_TAG: k2 diff --git a/pkg/k2/testdata/bucket_ro/my-container.resources.yaml b/pkg/k2/testdata/bucket_ro/my-container.resources.yaml index 07394d19b..91f058285 100644 --- a/pkg/k2/testdata/bucket_ro/my-container.resources.yaml +++ b/pkg/k2/testdata/bucket_ro/my-container.resources.yaml @@ -262,7 +262,6 @@ resources: aws:s3_bucket:my-bucket: ForceDestroy: true Id: preview(id=aws:s3_bucket:my-bucket) - IndexDocument: "" SSEAlgorithm: aws:kms Tags: GLOBAL_KLOTHO_TAG: k2 diff --git a/pkg/k2/testdata/dynamo/my-dynamodb.resources.yaml b/pkg/k2/testdata/dynamo/my-dynamodb.resources.yaml index bde5b11fd..030757620 100644 --- a/pkg/k2/testdata/dynamo/my-dynamodb.resources.yaml +++ b/pkg/k2/testdata/dynamo/my-dynamodb.resources.yaml @@ -25,7 +25,7 @@ resources: RESOURCE_NAME: my-dynamodb edges: outputs: - TableName: - ref: aws:dynamodb_table:my-dynamodb#Name TableArn: ref: aws:dynamodb_table:my-dynamodb#Arn + TableName: + ref: aws:dynamodb_table:my-dynamodb#Name diff --git a/pkg/k2/testdata/function/docker-func.engine_input.yaml b/pkg/k2/testdata/function/docker-func.engine_input.yaml index 87ef6b242..b6fec6288 100644 --- a/pkg/k2/testdata/function/docker-func.engine_input.yaml +++ b/pkg/k2/testdata/function/docker-func.engine_input.yaml @@ -20,11 +20,6 @@ constraints: target: aws:lambda_function:docker-func-function property: EnvironmentVariables value: {} - - scope: resource - operator: equals - target: aws:lambda_function:docker-func-function - property: Handler - value: null - scope: resource operator: equals target: aws:lambda_function:docker-func-function @@ -40,11 +35,6 @@ constraints: target: aws:lambda_function:docker-func-function property: PackageType value: Image - - scope: resource - operator: equals - target: aws:lambda_function:docker-func-function - property: Runtime - value: nodejs14.x - scope: resource operator: equals target: aws:lambda_function:docker-func-function diff --git a/pkg/k2/testdata/function/docker-func.resources.yaml b/pkg/k2/testdata/function/docker-func.resources.yaml index e0f3981eb..3e6a0ab00 100644 --- a/pkg/k2/testdata/function/docker-func.resources.yaml +++ b/pkg/k2/testdata/function/docker-func.resources.yaml @@ -2,12 +2,11 @@ resources: aws:lambda_function:docker-func-function: EnvironmentVariables: {} ExecutionRole: aws:iam_role:docker-func-function-ExecutionRole - Handler: "" Image: aws:ecr_image:docker-func-image#ImageName LogConfig: Format: Text MemorySize: 128 - Runtime: nodejs14.x + Runtime: nodejs20.x Tags: GLOBAL_KLOTHO_TAG: k2 RESOURCE_NAME: docker-func-function diff --git a/pkg/k2/testdata/function/infra.py b/pkg/k2/testdata/function/infra.py index 4e9df6f1b..4be234d2e 100644 --- a/pkg/k2/testdata/function/infra.py +++ b/pkg/k2/testdata/function/infra.py @@ -2,6 +2,7 @@ import klotho import klotho.aws as aws +from klotho.aws.api import RouteArgs klotho.Application( "my-app", @@ -22,4 +23,4 @@ ) api = aws.Api("my-api") -api.route_to("/", docker_func) +api.route([RouteArgs(path="/")], destination=docker_func) diff --git a/pkg/k2/testdata/function/my-api.engine_input.yaml b/pkg/k2/testdata/function/my-api.engine_input.yaml index 33aa81ab0..409cb67aa 100644 --- a/pkg/k2/testdata/function/my-api.engine_input.yaml +++ b/pkg/k2/testdata/function/my-api.engine_input.yaml @@ -44,11 +44,10 @@ resources: aws:lambda_function:docker-func-function: EnvironmentVariables: {} FunctionName: preview(id=aws:lambda_function:docker-func-function) - Handler: "" LogConfig: Format: Text MemorySize: 128 - Runtime: nodejs14.x + Runtime: nodejs20.x Tags: GLOBAL_KLOTHO_TAG: k2 RESOURCE_NAME: docker-func-function diff --git a/pkg/k2/testdata/function/my-api.resources.yaml b/pkg/k2/testdata/function/my-api.resources.yaml index aace16ec9..cd137e549 100644 --- a/pkg/k2/testdata/function/my-api.resources.yaml +++ b/pkg/k2/testdata/function/my-api.resources.yaml @@ -40,11 +40,10 @@ resources: aws:lambda_function:docker-func-function: EnvironmentVariables: {} FunctionName: preview(id=aws:lambda_function:docker-func-function) - Handler: "" LogConfig: Format: Text MemorySize: 128 - Runtime: nodejs14.x + Runtime: nodejs20.x Tags: GLOBAL_KLOTHO_TAG: k2 RESOURCE_NAME: docker-func-function diff --git a/pkg/k2/testdata/function/my-bucket.engine_input.yaml b/pkg/k2/testdata/function/my-bucket.engine_input.yaml index c6fb00ea1..2595dda6a 100644 --- a/pkg/k2/testdata/function/my-bucket.engine_input.yaml +++ b/pkg/k2/testdata/function/my-bucket.engine_input.yaml @@ -7,11 +7,6 @@ constraints: target: aws:s3_bucket:my-bucket property: ForceDestroy value: true - - scope: resource - operator: equals - target: aws:s3_bucket:my-bucket - property: IndexDocument - value: null - scope: resource operator: equals target: aws:s3_bucket:my-bucket diff --git a/pkg/k2/testdata/function/my-bucket.resources.yaml b/pkg/k2/testdata/function/my-bucket.resources.yaml index 4e9a4d2c2..9b2ef85cf 100644 --- a/pkg/k2/testdata/function/my-bucket.resources.yaml +++ b/pkg/k2/testdata/function/my-bucket.resources.yaml @@ -1,7 +1,6 @@ resources: aws:s3_bucket:my-bucket: ForceDestroy: true - IndexDocument: "" SSEAlgorithm: aws:kms Tags: GLOBAL_KLOTHO_TAG: k2 diff --git a/pkg/k2/testdata/function/zip-func.engine_input.yaml b/pkg/k2/testdata/function/zip-func.engine_input.yaml index da982582f..c542d3597 100644 --- a/pkg/k2/testdata/function/zip-func.engine_input.yaml +++ b/pkg/k2/testdata/function/zip-func.engine_input.yaml @@ -58,7 +58,6 @@ resources: aws:s3_bucket:my-bucket: ForceDestroy: true Id: preview(id=aws:s3_bucket:my-bucket) - IndexDocument: "" SSEAlgorithm: aws:kms Tags: GLOBAL_KLOTHO_TAG: k2 diff --git a/pkg/k2/testdata/function/zip-func.resources.yaml b/pkg/k2/testdata/function/zip-func.resources.yaml index c2553243f..ca1538c69 100644 --- a/pkg/k2/testdata/function/zip-func.resources.yaml +++ b/pkg/k2/testdata/function/zip-func.resources.yaml @@ -51,7 +51,6 @@ resources: aws:s3_bucket:my-bucket: ForceDestroy: true Id: preview(id=aws:s3_bucket:my-bucket) - IndexDocument: "" SSEAlgorithm: aws:kms Tags: GLOBAL_KLOTHO_TAG: k2 diff --git a/pkg/k2/testdata/simple_api/infra.py b/pkg/k2/testdata/simple_api/infra.py index 043b2c650..3bdf80072 100644 --- a/pkg/k2/testdata/simple_api/infra.py +++ b/pkg/k2/testdata/simple_api/infra.py @@ -1,5 +1,6 @@ import klotho import klotho.aws as aws +from klotho.aws.api import RouteArgs app = klotho.Application( "my-app", @@ -15,4 +16,4 @@ ) api = aws.Api("my-api") -api.route_to("/", container) +api.route([RouteArgs(path="/")], destination=container) diff --git a/pkg/k2/testdata/simple_api/my-api.engine_input.yaml b/pkg/k2/testdata/simple_api/my-api.engine_input.yaml index a16fd5597..1382b3767 100644 --- a/pkg/k2/testdata/simple_api/my-api.engine_input.yaml +++ b/pkg/k2/testdata/simple_api/my-api.engine_input.yaml @@ -1,7 +1,10 @@ constraints: - scope: application operator: must_exist - node: aws:api_integration:my-api:my-container + node: aws:api_integration:my-api:--any-integration + - scope: application + operator: must_exist + node: aws:api_method:my-api-api:--any-method - scope: application operator: must_exist node: aws:api_stage:my-api-api:my-api-stage @@ -10,9 +13,19 @@ constraints: node: aws:rest_api:my-api-api - scope: resource operator: equals - target: aws:api_integration:my-api:my-container + target: aws:api_integration:my-api:--any-integration + property: Method + value: aws:api_method:my-api-api:--any-method + - scope: resource + operator: equals + target: aws:api_integration:my-api:--any-integration property: Route value: / + - scope: resource + operator: equals + target: aws:api_method:my-api-api:--any-method + property: HttpMethod + value: ANY - scope: resource operator: equals target: aws:api_stage:my-api-api:my-api-stage @@ -26,14 +39,14 @@ constraints: - scope: edge operator: must_exist target: - source: aws:api_integration:my-api:my-container + source: aws:api_integration:my-api:--any-integration target: aws:load_balancer:api-my-container-lb data: {} - scope: edge operator: must_exist target: source: aws:rest_api:my-api-api - target: aws:api_integration:my-api:my-container + target: aws:api_integration:my-api:--any-integration data: {} - scope: output operator: must_exist diff --git a/pkg/k2/testdata/simple_api/my-api.index.ts b/pkg/k2/testdata/simple_api/my-api.index.ts index 0613e8a8b..3414fd163 100644 --- a/pkg/k2/testdata/simple_api/my-api.index.ts +++ b/pkg/k2/testdata/simple_api/my-api.index.ts @@ -17,8 +17,8 @@ const my_api_api = new aws.apigateway.RestApi("my-api-api", { tags: {GLOBAL_KLOTHO_TAG: "k2", RESOURCE_NAME: "my-api-api"}, }) const default_network_vpc = aws.ec2.Vpc.get("default-network-vpc", "preview(id=aws:vpc:default-network-vpc)") -const my_container_api_method = new aws.apigateway.Method( - "my-container-api_method", +const __any_method = new aws.apigateway.Method( + "--any-method", { restApi: my_api_api.id, resourceId: my_api_api.rootResourceId, @@ -38,34 +38,34 @@ const my_container_tg = aws.lb.TargetGroup.get("my-container-tg", "preview(id=aw const api_my_container_lb = aws.lb.LoadBalancer.get("api-my-container-lb", "preview(id=aws:load_balancer:api-my-container-lb)") export const api_my_container_lb_DomainName = api_my_container_lb.dnsName const my_container_service = aws.ecs.Service.get("my-container-service", "preview(id=aws:ecs_service:my-container-service)".split('/').slice(-2).join('/')) -const my_container_api_my_container_lb = new aws.apigateway.VpcLink("my-container-api-my-container-lb", { +const __any_integration_api_my_container_lb = new aws.apigateway.VpcLink("--any-integration-api-my-container-lb", { targetArn: api_my_container_lb.arn, - tags: {GLOBAL_KLOTHO_TAG: "k2", RESOURCE_NAME: "my-container-api-my-container-lb"}, + tags: {GLOBAL_KLOTHO_TAG: "k2", RESOURCE_NAME: "--any-integration-api-my-container-lb"}, }) -const my_container = new aws.apigateway.Integration( - "my-container", +const __any_integration = new aws.apigateway.Integration( + "--any-integration", { restApi: my_api_api.id, resourceId: my_api_api.rootResourceId, - httpMethod: my_container_api_method.httpMethod, + httpMethod: __any_method.httpMethod, integrationHttpMethod: "ANY", type: "HTTP_PROXY", connectionType: "VPC_LINK", - connectionId: my_container_api_my_container_lb.id, + connectionId: __any_integration_api_my_container_lb.id, uri: pulumi.interpolate`http://${ (api_my_container_lb as aws.lb.LoadBalancer).dnsName }${"/".replace('+', '')}`, }, - { parent: my_container_api_method } + { parent: __any_method } ) const api_deployment_0 = new aws.apigateway.Deployment( "api_deployment-0", { restApi: my_api_api.id, - triggers: {myContainer: "my-container", myContainerApiMethod: "my-container-api_method"}, + triggers: {AnyIntegration: "--any-integration", AnyMethod: "--any-method"}, }, { - dependsOn: [my_api_api, my_container, my_container_api_method], + dependsOn: [__any_integration, __any_method, my_api_api], } ) const my_api_stage = new aws.apigateway.Stage("my-api-stage", { @@ -84,7 +84,7 @@ export const $urns = { "aws:ecs_cluster:ecs_cluster-0": (ecs_cluster_0 as any).urn, "aws:rest_api:my-api-api": (my_api_api as any).urn, "aws:vpc:default-network-vpc": (default_network_vpc as any).urn, - "aws:api_method:my-api-api:my-container-api_method": (my_container_api_method as any).urn, + "aws:api_method:my-api-api:--any-method": (__any_method as any).urn, "aws:security_group:default-network-vpc:my-container-service-security_group": (my_container_service_security_group as any).urn, "aws:subnet:default-network-vpc:default-network-private-subnet-1": (default_network_private_subnet_1 as any).urn, "aws:subnet:default-network-vpc:default-network-private-subnet-2": (default_network_private_subnet_2 as any).urn, @@ -93,8 +93,8 @@ export const $urns = { "aws:target_group:my-container-tg": (my_container_tg as any).urn, "aws:load_balancer:api-my-container-lb": (api_my_container_lb as any).urn, "aws:ecs_service:my-container-service": (my_container_service as any).urn, - "aws:vpc_link:my-container-api-my-container-lb": (my_container_api_my_container_lb as any).urn, - "aws:api_integration:my-api-api:my-container": (my_container as any).urn, + "aws:vpc_link:--any-integration-api-my-container-lb": (__any_integration_api_my_container_lb as any).urn, + "aws:api_integration:my-api-api:--any-integration": (__any_integration as any).urn, "aws:api_deployment:my-api-api:api_deployment-0": (api_deployment_0 as any).urn, "aws:api_stage:my-api-api:my-api-stage": (my_api_stage as any).urn, } diff --git a/pkg/k2/testdata/simple_api/my-api.resources.yaml b/pkg/k2/testdata/simple_api/my-api.resources.yaml index ef8b4507b..9c0c606a4 100644 --- a/pkg/k2/testdata/simple_api/my-api.resources.yaml +++ b/pkg/k2/testdata/simple_api/my-api.resources.yaml @@ -46,8 +46,8 @@ resources: aws:api_deployment:my-api-api:api_deployment-0: RestApi: aws:rest_api:my-api-api Triggers: - my-container: my-container - my-container-api_method: my-container-api_method + --any-integration: --any-integration + --any-method: --any-method aws:rest_api:my-api-api: BinaryMediaTypes: - application/octet-stream @@ -55,26 +55,26 @@ resources: Tags: GLOBAL_KLOTHO_TAG: k2 RESOURCE_NAME: my-api-api - aws:api_method:my-api-api:my-container-api_method: + aws:api_method:my-api-api:--any-method: Authorization: NONE HttpMethod: ANY RequestParameters: {} RestApi: aws:rest_api:my-api-api - aws:api_integration:my-api-api:my-container: + aws:api_integration:my-api-api:--any-integration: ConnectionType: VPC_LINK IntegrationHttpMethod: ANY - Method: aws:api_method:my-api-api:my-container-api_method + Method: aws:api_method:my-api-api:--any-method RequestParameters: {} RestApi: aws:rest_api:my-api-api Route: / Target: aws:load_balancer:api-my-container-lb Type: HTTP_PROXY - Uri: aws:api_integration:my-api-api:my-container#LbUri - VpcLink: aws:vpc_link:my-container-api-my-container-lb - aws:vpc_link:my-container-api-my-container-lb: + Uri: aws:api_integration:my-api-api:--any-integration#LbUri + VpcLink: aws:vpc_link:--any-integration-api-my-container-lb + aws:vpc_link:--any-integration-api-my-container-lb: Tags: GLOBAL_KLOTHO_TAG: k2 - RESOURCE_NAME: my-container-api-my-container-lb + RESOURCE_NAME: --any-integration-api-my-container-lb Target: aws:load_balancer:api-my-container-lb aws:load_balancer:api-my-container-lb: Id: preview(id=aws:load_balancer:api-my-container-lb) @@ -185,14 +185,14 @@ edges: aws:subnet:default-network-vpc:default-network-public-subnet-1 -> aws:vpc:default-network-vpc: aws:subnet:default-network-vpc:default-network-public-subnet-2 -> aws:vpc:default-network-vpc: aws:target_group:my-container-tg -> aws:ecs_service:my-container-service: - aws:api_deployment:my-api-api:api_deployment-0 -> aws:api_integration:my-api-api:my-container: - aws:api_deployment:my-api-api:api_deployment-0 -> aws:api_method:my-api-api:my-container-api_method: + aws:api_deployment:my-api-api:api_deployment-0 -> aws:api_integration:my-api-api:--any-integration: + aws:api_deployment:my-api-api:api_deployment-0 -> aws:api_method:my-api-api:--any-method: aws:api_deployment:my-api-api:api_deployment-0 -> aws:rest_api:my-api-api: - aws:rest_api:my-api-api -> aws:api_integration:my-api-api:my-container: - aws:rest_api:my-api-api -> aws:api_method:my-api-api:my-container-api_method: - aws:api_method:my-api-api:my-container-api_method -> aws:api_integration:my-api-api:my-container: - aws:api_integration:my-api-api:my-container -> aws:vpc_link:my-container-api-my-container-lb: - aws:vpc_link:my-container-api-my-container-lb -> aws:load_balancer:api-my-container-lb: + aws:rest_api:my-api-api -> aws:api_integration:my-api-api:--any-integration: + aws:rest_api:my-api-api -> aws:api_method:my-api-api:--any-method: + aws:api_method:my-api-api:--any-method -> aws:api_integration:my-api-api:--any-integration: + aws:api_integration:my-api-api:--any-integration -> aws:vpc_link:--any-integration-api-my-container-lb: + aws:vpc_link:--any-integration-api-my-container-lb -> aws:load_balancer:api-my-container-lb: aws:load_balancer:api-my-container-lb -> aws:subnet:default-network-vpc:default-network-private-subnet-1: aws:load_balancer:api-my-container-lb -> aws:subnet:default-network-vpc:default-network-private-subnet-2: aws:ecs_service:my-container-service -> aws:ecs_cluster:ecs_cluster-0: diff --git a/pkg/k2/testdata/simple_api/my-container.engine_input.yaml b/pkg/k2/testdata/simple_api/my-container.engine_input.yaml index 6e4b5afc4..1f9bc7cf0 100644 --- a/pkg/k2/testdata/simple_api/my-container.engine_input.yaml +++ b/pkg/k2/testdata/simple_api/my-container.engine_input.yaml @@ -42,7 +42,7 @@ constraints: target: aws:ecs_service:my-container-service property: LoadBalancers[0] value: - ContainerName: null + ContainerName: my-container ContainerPort: 80 TargetGroup: aws:target_group:my-container-tg - scope: resource diff --git a/pkg/knowledgebase/kb.go b/pkg/knowledgebase/kb.go index 34f00ee65..c1c6bc1a5 100644 --- a/pkg/knowledgebase/kb.go +++ b/pkg/knowledgebase/kb.go @@ -382,7 +382,7 @@ resourceLoop: errs = errors.Join(errs, err) continue } - preXform := path.Get() + preXform, _ := path.Get() if preXform == nil { continue } diff --git a/pkg/reflectutil/map.go b/pkg/reflectutil/map.go new file mode 100644 index 000000000..8f19a5af6 --- /dev/null +++ b/pkg/reflectutil/map.go @@ -0,0 +1,25 @@ +package reflectutil + +import ( + "fmt" + "reflect" +) + +func MapContainsKey(m any, key interface{}) (bool, error) { + var mapValue reflect.Value + if mValue, ok := m.(reflect.Value); ok { + mapValue = mValue + } else { + mapValue = reflect.ValueOf(m) + } + if mapValue.Kind() != reflect.Map { + return false, fmt.Errorf("value is not a map") + } + + keyValue := reflect.ValueOf(key) + if !keyValue.IsValid() { + return false, fmt.Errorf("invalid key") + } + + return mapValue.MapIndex(keyValue).IsValid(), nil +} diff --git a/pkg/k2/reflectutil/reflectutil.go b/pkg/reflectutil/reflectutil.go similarity index 74% rename from pkg/k2/reflectutil/reflectutil.go rename to pkg/reflectutil/reflectutil.go index b8c973824..22c516a59 100644 --- a/pkg/k2/reflectutil/reflectutil.go +++ b/pkg/reflectutil/reflectutil.go @@ -1,6 +1,7 @@ package reflectutil import ( + "errors" "fmt" "reflect" "strconv" @@ -14,7 +15,10 @@ This function is used to get the concrete value of a reflect.Value even if it is Concrete values are values that are not pointers or interfaces (including maps, slices, structs, etc.). */ func GetConcreteValue(v reflect.Value) any { - return GetConcreteElement(v).Interface() + if v.IsValid() { + return GetConcreteElement(v).Interface() + } + return nil } // IsNotConcrete returns true if the reflect.Value is a pointer or interface. @@ -31,18 +35,22 @@ func GetConcreteElement(v reflect.Value) reflect.Value { return v } -// GetField returns the reflect.Value of a field in a struct or map. +// GetField returns the [reflect.Value] of a field in a struct or map. func GetField(v reflect.Value, fieldExpr string) (reflect.Value, error) { if v.Kind() == reflect.Invalid { return reflect.Value{}, fmt.Errorf("value is nil") } - fields := strings.Split(fieldExpr, ".") + fields := SplitPath(fieldExpr) for _, field := range fields { + if strings.Contains(field, "[") != strings.Contains(field, "]") { + return reflect.Value{}, errors.New("invalid path: unclosed brackets ") + } + v = GetConcreteElement(v) // Handle array/slice indices - if strings.Contains(field, "[") { + if strings.Contains(field, "[") && !strings.Contains(field, ".") { fieldName := field[:strings.Index(field, "[")] indexStr := field[strings.Index(field, "[")+1 : strings.Index(field, "]")] index, err := strconv.Atoi(indexStr) @@ -75,6 +83,7 @@ func GetField(v reflect.Value, fieldExpr string) (reflect.Value, error) { return reflect.Value{}, fmt.Errorf("field is not a slice or array: %s", fieldName) } } else { + field = strings.TrimSuffix(strings.TrimLeft(field, ".["), "]") switch v.Kind() { case reflect.Map: if v.Type().Key().Kind() != reflect.String { @@ -186,3 +195,53 @@ func LastOfType[T any](values []reflect.Value) (T, bool) { // Use FirstOfType on the reversed slice return FirstOfType[T](reversed) } + +// IsAnyOf returns true if the [reflect.Value] is any of the specified types. +func IsAnyOf(v reflect.Value, types ...reflect.Kind) bool { + for _, t := range types { + if v.Kind() == t { + return true + } + } + return false +} + +// SplitPath splits a path string into parts separated by '.' and '[', ']'. +// It is used to split a path string into parts that can be used to access fields in a slice, array, struct, or map. +// Bracketed components are treated as a single part, including the brackets. +func SplitPath(path string) []string { + var parts []string + bracket := 0 + lastPartIdx := 0 + for i := 0; i < len(path); i++ { + switch path[i] { + case '.': + if bracket == 0 { + if i > lastPartIdx { + parts = append(parts, path[lastPartIdx:i]) + } + lastPartIdx = i + } + + case '[': + if bracket == 0 { + if i > lastPartIdx { + parts = append(parts, path[lastPartIdx:i]) + } + lastPartIdx = i + } + bracket++ + + case ']': + bracket-- + if bracket == 0 { + parts = append(parts, path[lastPartIdx:i+1]) + lastPartIdx = i + 1 + } + } + if i == len(path)-1 && lastPartIdx <= i { + parts = append(parts, path[lastPartIdx:]) + } + } + return parts +} diff --git a/pkg/k2/reflectutil/reflectutil_test.go b/pkg/reflectutil/reflectutil_test.go similarity index 92% rename from pkg/k2/reflectutil/reflectutil_test.go rename to pkg/reflectutil/reflectutil_test.go index 7bbfb5845..5d96b9074 100644 --- a/pkg/k2/reflectutil/reflectutil_test.go +++ b/pkg/reflectutil/reflectutil_test.go @@ -1,6 +1,7 @@ package reflectutil import ( + "github.com/stretchr/testify/assert" "reflect" "testing" ) @@ -108,6 +109,17 @@ func TestGetField(t *testing.T) { C: 3.14, } + exampleMapWithPeriods := map[string]any{ + "A.B": map[string]any{ + "C.D": "value", + }, + "E": map[string]any{ + "F.G": []map[string]any{ + {"H.I": "nested value"}, + }, + }, + } + tests := []struct { name string v any @@ -154,7 +166,6 @@ func TestGetField(t *testing.T) { name: "Empty fields", v: exampleEmptyFields, fieldExpr: "A.B[0].C", - want: nil, wantErr: true, }, { @@ -168,21 +179,18 @@ func TestGetField(t *testing.T) { name: "Invalid field name", v: example, fieldExpr: "A.X[0].C", - want: nil, wantErr: true, }, { name: "Index out of range", v: example, fieldExpr: "A.B[1].C", - want: nil, wantErr: true, }, { name: "Field is not slice or array", v: example, fieldExpr: "A.B.C", - want: nil, wantErr: true, }, { @@ -225,18 +233,55 @@ func TestGetField(t *testing.T) { fieldExpr: "A", wantErr: true, }, + { + name: "Map access with period in key", + v: exampleMapWithPeriods, + fieldExpr: `[A.B][C.D]`, + want: "value", + }, + { + name: "Nested map access with period in key", + v: exampleMapWithPeriods, + fieldExpr: `E[F.G][0][H.I]`, + want: "nested value", + }, + { + name: "Mixed dot and bracket notation", + v: exampleMapWithPeriods, + fieldExpr: `E[F.G].0[H.I]`, + want: "nested value", + }, + { + name: "Missing closing bracket", + v: exampleMapWithPeriods, + fieldExpr: `[A.B`, + wantErr: true, + }, + { + name: "Empty brackets", + v: exampleMapWithPeriods, + fieldExpr: `[]`, + wantErr: true, + }, + { + name: "Bracket notation for non-map", + v: example, + fieldExpr: `A[B]`, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) v := reflect.ValueOf(tt.v) got, err := GetField(v, tt.fieldExpr) - if (err != nil) != tt.wantErr { - t.Errorf("GetField() error = %v, wantErr %v", err, tt.wantErr) + + if tt.wantErr { + assert.Errorf(err, "GetField() wantErr %v", tt.wantErr) return - } - if !tt.wantErr && !reflect.DeepEqual(got.Interface(), tt.want) { - t.Errorf("GetField() = %v, want %v", got.Interface(), tt.want) + } else { + assert.EqualValues(tt.want, GetConcreteValue(got)) } }) } diff --git a/pkg/templateutils/funcs.go b/pkg/templateutils/funcs.go index 3a39a7dc2..ba47f9fcd 100644 --- a/pkg/templateutils/funcs.go +++ b/pkg/templateutils/funcs.go @@ -20,8 +20,10 @@ var UtilityFunctions = template.FuncMap{ "zipToMap": ZipToMap, "keysToMapWithDefault": KeysToMapWithDefault, "replace": ReplaceAllRegex, - "hasSuffix": HasSuffix, + "replaceAll": ReplaceAll, + "hasSuffix": strings.HasSuffix, "toLower": strings.ToLower, + "toUpper": strings.ToUpper, "add": Add, "sub": Sub, "last": Last, @@ -29,6 +31,11 @@ var UtilityFunctions = template.FuncMap{ "appendSlice": AppendSlice, "sliceContains": SliceContains, "matches": Matches, + "trimLeft": strings.TrimLeft, + "trimRight": strings.TrimRight, + "trimSpace": strings.TrimSpace, + "trimPrefix": strings.TrimPrefix, + "trimSuffix": strings.TrimSuffix, } func WithCommonFuncs(funcMap template.FuncMap) template.FuncMap { @@ -192,14 +199,14 @@ func ReplaceAllRegex(pattern, replace, value string) (string, error) { return s, nil } -// MakeSlice creates and returns a new slice of any type. -func MakeSlice() []any { - return []any{} +// MakeSlice creates and returns a new slice of any type with the given values. +func MakeSlice(args ...any) []any { + return args } -// AppendSlice appends a value to a slice and returns the updated slice. -func AppendSlice(slice []any, value any) []any { - return append(slice, value) +// AppendSlice appends any number of values to a slice and returns the new slice. +func AppendSlice(slice []any, value ...any) []any { + return append(slice, value...) } // SliceContains checks if a slice contains a specific value. @@ -211,8 +218,3 @@ func SliceContains(slice []any, value any) bool { } return false } - -// HasSuffix checks if a string has a specific suffix. -func HasSuffix(s, suffix string) bool { - return strings.HasSuffix(s, suffix) -} diff --git a/pkg/templateutils/funcs_test.go b/pkg/templateutils/funcs_test.go new file mode 100644 index 000000000..67e885e2b --- /dev/null +++ b/pkg/templateutils/funcs_test.go @@ -0,0 +1,386 @@ +package templateutils + +import ( + "reflect" + "testing" + "text/template" +) + +func TestUtilityFunctions(t *testing.T) { + tests := []struct { + name string + funcName string + args []any + want any + wantErr bool + }{ + // split tests + {"Split basic", "split", []any{"a,b,c", ","}, []string{"a", "b", "c"}, false}, + {"Split no separator", "split", []any{"abc", ","}, []string{"abc"}, false}, + {"Split empty string", "split", []any{"", ","}, []string{""}, false}, + {"Split with empty separator", "split", []any{"abc", ""}, []string{"a", "b", "c"}, false}, + + // join tests + {"Join basic", "join", []any{[]string{"a", "b", "c"}, ","}, "a,b,c", false}, + {"Join empty slice", "join", []any{[]string{}, ","}, "", false}, + {"Join single element", "join", []any{[]string{"a"}, ","}, "a", false}, + {"Join with empty separator", "join", []any{[]string{"a", "b", "c"}, ""}, "abc", false}, + + // basename tests + {"Basename basic", "basename", []any{"/path/to/file.txt"}, "file.txt", false}, + {"Basename no directory", "basename", []any{"file.txt"}, "file.txt", false}, + {"Basename with trailing slash", "basename", []any{"/path/to/directory/"}, "directory", false}, + {"Basename empty string", "basename", []any{""}, ".", false}, + + // filterMatch tests + {"FilterMatch basic", "filterMatch", []any{"^a", []string{"apple", "banana", "avocado"}}, []string{"apple", "avocado"}, false}, + {"FilterMatch no matches", "filterMatch", []any{"^z", []string{"apple", "banana", "avocado"}}, []string{}, false}, + {"FilterMatch empty slice", "filterMatch", []any{"^a", []string{}}, []string{}, false}, + {"FilterMatch invalid regex", "filterMatch", []any{"[", []string{"apple", "banana", "avocado"}}, nil, true}, + + // mapString tests + {"MapString basic", "mapString", []any{"a", "A", []string{"apple", "banana", "avocado"}}, []string{"Apple", "bAnAnA", "AvocAdo"}, false}, + {"MapString no matches", "mapString", []any{"z", "Z", []string{"apple", "banana", "avocado"}}, []string{"apple", "banana", "avocado"}, false}, + {"MapString empty slice", "mapString", []any{"a", "A", []string{}}, []string{}, false}, + {"MapString invalid regex", "mapString", []any{"[", "A", []string{"apple", "banana", "avocado"}}, nil, true}, + + // zipToMap tests + {"ZipToMap basic", "zipToMap", []any{[]string{"a", "b"}, []int{1, 2}}, map[string]any{"a": 1, "b": 2}, false}, + {"ZipToMap empty slices", "zipToMap", []any{[]string{}, []int{}}, map[string]any{}, false}, + {"ZipToMap mismatched lengths", "zipToMap", []any{[]string{"a", "b"}, []int{1}}, nil, true}, + {"ZipToMap non-slice values", "zipToMap", []any{[]string{"a", "b"}, 1}, nil, true}, + + // keysToMapWithDefault tests + {"KeysToMapWithDefault basic", "keysToMapWithDefault", []any{0, []string{"a", "b"}}, map[string]any{"a": 0, "b": 0}, false}, + {"KeysToMapWithDefault empty slice", "keysToMapWithDefault", []any{0, []string{}}, map[string]any{}, false}, + {"KeysToMapWithDefault string default", "keysToMapWithDefault", []any{"default", []string{"a", "b"}}, map[string]any{"a": "default", "b": "default"}, false}, + + // replaceAll tests + {"ReplaceAll basic", "replaceAll", []any{"hello world", "o", "0"}, "hell0 w0rld", false}, + {"ReplaceAll no matches", "replaceAll", []any{"hello world", "z", "0"}, "hello world", false}, + {"ReplaceAll empty string", "replaceAll", []any{"", "o", "0"}, "", false}, + {"ReplaceAll replace with empty", "replaceAll", []any{"hello world", "o", ""}, "hell wrld", false}, + + // hasSuffix tests + {"HasSuffix true", "hasSuffix", []any{"filename.txt", ".txt"}, true, false}, + {"HasSuffix false", "hasSuffix", []any{"filename.txt", ".jpg"}, false, false}, + {"HasSuffix empty suffix", "hasSuffix", []any{"filename.txt", ""}, true, false}, + {"HasSuffix empty string", "hasSuffix", []any{"", ".txt"}, false, false}, + + // toLower tests + {"ToLower basic", "toLower", []any{"Hello World"}, "hello world", false}, + {"ToLower already lowercase", "toLower", []any{"hello world"}, "hello world", false}, + {"ToLower empty string", "toLower", []any{""}, "", false}, + {"ToLower with numbers", "toLower", []any{"HeLLo 123"}, "hello 123", false}, + + // toUpper tests + {"ToUpper basic", "toUpper", []any{"Hello World"}, "HELLO WORLD", false}, + {"ToUpper already uppercase", "toUpper", []any{"HELLO WORLD"}, "HELLO WORLD", false}, + {"ToUpper empty string", "toUpper", []any{""}, "", false}, + {"ToUpper with numbers", "toUpper", []any{"HeLLo 123"}, "HELLO 123", false}, + + // add tests + {"Add basic", "add", []any{1, 2, 3}, 6, false}, + {"Add single number", "add", []any{5}, 5, false}, + {"Add no numbers", "add", []any{}, 0, false}, + {"Add negative numbers", "add", []any{1, -2, 3}, 2, false}, + + // sub tests + {"Sub basic", "sub", []any{10, 3, 2}, 5, false}, + {"Sub single number", "sub", []any{5}, 5, false}, + {"Sub no numbers", "sub", []any{}, 0, false}, + {"Sub negative numbers", "sub", []any{1, -2, 3}, 0, false}, + + // last tests + {"Last basic", "last", []any{[]int{1, 2, 3}}, 3, false}, + {"Last single element", "last", []any{[]int{1}}, 1, false}, + {"Last empty slice", "last", []any{[]int{}}, nil, true}, + {"Last non-slice", "last", []any{1}, nil, true}, + + // makeSlice tests + {"MakeSlice basic", "makeSlice", []any{}, []any{}, false}, + + // appendSlice tests + {"AppendSlice basic", "appendSlice", []any{[]any{1, 2}, 3}, []any{1, 2, 3}, false}, + {"AppendSlice to empty slice", "appendSlice", []any{[]any{}, 1}, []any{1}, false}, + {"AppendSlice different types", "appendSlice", []any{[]any{1, "two"}, 3.0}, []any{1, "two", 3.0}, false}, + + // sliceContains tests + {"SliceContains true", "sliceContains", []any{[]any{1, 2, 3}, 2}, true, false}, + {"SliceContains false", "sliceContains", []any{[]any{1, 2, 3}, 4}, false, false}, + {"SliceContains empty slice", "sliceContains", []any{[]any{}, 1}, false, false}, + {"SliceContains different types", "sliceContains", []any{[]any{1, "two", 3.0}, "two"}, true, false}, + + // matches tests + {"Matches true", "matches", []any{"^a", "apple"}, true, false}, + {"Matches false", "matches", []any{"^b", "apple"}, false, false}, + {"Matches empty string", "matches", []any{".*", ""}, true, false}, + {"Matches invalid regex", "matches", []any{"[", "apple"}, nil, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fn := UtilityFunctions[tt.funcName] + if fn == nil { + t.Fatalf("function %s not found in UtilityFunctions", tt.funcName) + } + + fnValue := reflect.ValueOf(fn) + args := make([]reflect.Value, len(tt.args)) + for i, arg := range tt.args { + args[i] = reflect.ValueOf(arg) + } + + results := fnValue.Call(args) + + if len(results) == 0 { + t.Fatalf("function %s returned no results", tt.funcName) + } + + var got any + var err error + + if len(results) == 2 { + got = results[0].Interface() + err, _ = results[1].Interface().(error) + } else { + got = results[0].Interface() + } + + if (err != nil) != tt.wantErr { + t.Errorf("unexpected error: %v", err) + return + } + + if tt.wantErr { + if err == nil { + t.Errorf("expected an error, but got nil") + } + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +// Additional tests for functions not in UtilityFunctions map + +func TestToJSON(t *testing.T) { + tests := []struct { + name string + input any + want string + wantErr bool + }{ + {"Basic map", map[string]int{"a": 1, "b": 2}, `{"a":1,"b":2}`, false}, + {"Basic slice", []int{1, 2, 3}, `[1,2,3]`, false}, + {"Empty map", map[string]int{}, `{}`, false}, + {"Empty slice", []int{}, `[]`, false}, + {"Nested structure", map[string]any{"a": 1, "b": []int{2, 3}}, `{"a":1,"b":[2,3]}`, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToJSON(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ToJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToJSONPretty(t *testing.T) { + tests := []struct { + name string + input any + want string + wantErr bool + }{ + {"Basic map", map[string]int{"a": 1, "b": 2}, "{\n \"a\": 1,\n \"b\": 2\n}", false}, + {"Basic slice", []int{1, 2, 3}, "[\n 1,\n 2,\n 3\n]", false}, + {"Empty map", map[string]int{}, "{}", false}, + {"Empty slice", []int{}, "[]", false}, + {"Nested structure", map[string]any{"a": 1, "b": []int{2, 3}}, "{\n \"a\": 1,\n \"b\": [\n 2,\n 3\n ]\n}", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ToJSONPretty(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ToJSONPretty() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ToJSONPretty() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileBase(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"Basic path", "/path/to/file.txt", "file.txt"}, + {"No directory", "file.txt", "file.txt"}, + {"Trailing slash", "/path/to/directory/", "directory"}, + {"Empty string", "", "."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FileBase(tt.input); got != tt.want { + t.Errorf("FileBase() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileTrimExtFunc(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"Basic file", "file.txt", "file"}, + {"No extension", "file", "file"}, + {"Multiple dots", "file.tar.gz", "file.tar"}, + {"Hidden file (not supported)", ".hidden", ""}, // change want to ".hidden" if adding support for hidden files + {"Empty string", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FileTrimExtFunc(tt.input); got != tt.want { + t.Errorf("FileTrimExtFunc() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestFileSep(t *testing.T) { + got := FileSep() + if got != "/" && got != "\\" { + t.Errorf("FileSep() = %v, want either '/' or '\\'", got) + } +} + +func TestReplaceAllRegex(t *testing.T) { + tests := []struct { + name string + pattern string + replace string + value string + want string + wantErr bool + }{ + {"Basic replacement", "a+", "b", "aaa bbb aaa", "b bbb b", false}, + {"No matches", "z+", "b", "aaa bbb aaa", "aaa bbb aaa", false}, + {"Empty string", "a+", "b", "", "", false}, + {"Replace with empty", "a+", "", "aaa bbb aaa", " bbb ", false}, + {"Invalid regex", "[", "b", "aaa bbb aaa", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ReplaceAllRegex(tt.pattern, tt.replace, tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("ReplaceAllRegex() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ReplaceAllRegex() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestWithCommonFuncs tests the WithCommonFuncs function +func TestWithCommonFuncs(t *testing.T) { + // Create a custom FuncMap + customFuncMap := template.FuncMap{ + "customFunc": func() string { return "custom" }, + } + + // Apply WithCommonFuncs + resultFuncMap := WithCommonFuncs(customFuncMap) + + // Check if the custom function is still present + if customFunc, ok := resultFuncMap["customFunc"]; !ok { + t.Errorf("WithCommonFuncs() did not preserve custom function") + } else { + if customFunc.(func() string)() != "custom" { + t.Errorf("WithCommonFuncs() altered custom function behavior") + } + } + + // Check if common functions were added + for funcName := range UtilityFunctions { + if _, ok := resultFuncMap[funcName]; !ok { + t.Errorf("WithCommonFuncs() did not add common function %s", funcName) + } + } + + // Ensure no functions were lost + expectedLength := len(UtilityFunctions) + 1 // +1 for the custom function + if len(resultFuncMap) != expectedLength { + t.Errorf("WithCommonFuncs() resulted in unexpected number of functions. Got %d, want %d", len(resultFuncMap), expectedLength) + } +} + +// Additional helper function to test error cases +func TestErrorCases(t *testing.T) { + errorTests := []struct { + name string + funcName string + args []any + wantErr bool + }{ + {"ZipToMap mismatched lengths", "zipToMap", []any{[]string{"a", "b"}, []int{1}}, true}, + {"ZipToMap non-slice values", "zipToMap", []any{[]string{"a", "b"}, 1}, true}, + {"FilterMatch invalid regex", "filterMatch", []any{"[", []string{"apple", "banana", "avocado"}}, true}, + {"MapString invalid regex", "mapString", []any{"[", "A", []string{"apple", "banana", "avocado"}}, true}, + {"Last empty slice", "last", []any{[]int{}}, true}, + {"Last non-slice", "last", []any{1}, true}, + {"Matches invalid regex", "matches", []any{"[", "apple"}, true}, + } + + for _, tt := range errorTests { + t.Run(tt.name, func(t *testing.T) { + fn := UtilityFunctions[tt.funcName] + if fn == nil { + t.Fatalf("function %s not found in UtilityFunctions", tt.funcName) + } + + fnValue := reflect.ValueOf(fn) + args := make([]reflect.Value, len(tt.args)) + for i, arg := range tt.args { + args[i] = reflect.ValueOf(arg) + } + + results := fnValue.Call(args) + + if len(results) != 2 { + t.Fatalf("expected function %s to return 2 values (result and error)", tt.funcName) + } + + err, ok := results[1].Interface().(error) + if !ok { + t.Fatalf("second return value of function %s is not an error", tt.funcName) + } + + if (err != nil) != tt.wantErr { + t.Errorf("function %s error = %v, wantErr %v", tt.funcName, err, tt.wantErr) + } + }) + } +}